Skip to main content

rustic_ai/
model.rs

1use std::pin::Pin;
2
3use async_trait::async_trait;
4use futures::stream::Stream;
5use serde_json::{Map, Value};
6use thiserror::Error;
7
8use crate::messages::{ModelMessage, ModelResponse, ToolCallPart};
9use crate::tools::ToolDefinition;
10use crate::usage::RequestUsage;
11
12pub type ModelSettings = Map<String, Value>;
13
/// Output format requested from the model.
///
/// The default is [`OutputMode::Text`].
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum OutputMode {
    /// Plain text output (the default).
    #[default]
    Text,
    /// Output described by a JSON schema; selected by
    /// `ModelRequestParameters::with_output_schema`.
    JsonSchema,
}
20
21#[derive(Clone, Debug)]
22pub struct ModelRequestParameters {
23    pub function_tools: Vec<ToolDefinition>,
24    pub output_schema: Option<Value>,
25    pub output_mode: OutputMode,
26    pub allow_text_output: bool,
27}
28
29impl ModelRequestParameters {
30    pub fn new(function_tools: Vec<ToolDefinition>) -> Self {
31        Self {
32            function_tools,
33            output_schema: None,
34            output_mode: OutputMode::Text,
35            allow_text_output: true,
36        }
37    }
38
39    pub fn with_output_schema(mut self, schema: Value) -> Self {
40        self.output_schema = Some(schema);
41        self.output_mode = OutputMode::JsonSchema;
42        self.allow_text_output = false;
43        self
44    }
45}
46
47impl Default for ModelRequestParameters {
48    fn default() -> Self {
49        Self {
50            function_tools: Vec::new(),
51            output_schema: None,
52            output_mode: OutputMode::Text,
53            allow_text_output: true,
54        }
55    }
56}
57
58#[derive(Clone, Debug)]
59pub struct StreamChunk {
60    pub text_delta: Option<String>,
61    pub tool_call: Option<ToolCallPart>,
62    pub finish_reason: Option<String>,
63    pub usage: Option<RequestUsage>,
64}
65
66pub type ModelStream = Pin<Box<dyn Stream<Item = Result<StreamChunk, ModelError>> + Send>>;
67
68#[async_trait]
69pub trait Model: Send + Sync {
70    fn name(&self) -> &str;
71
72    async fn request(
73        &self,
74        messages: &[ModelMessage],
75        settings: Option<&ModelSettings>,
76        params: &ModelRequestParameters,
77    ) -> Result<ModelResponse, ModelError>;
78
79    async fn count_tokens(
80        &self,
81        _messages: &[ModelMessage],
82        _settings: Option<&ModelSettings>,
83        _params: &ModelRequestParameters,
84    ) -> Result<RequestUsage, ModelError> {
85        Err(ModelError::Unsupported(
86            "token counting not supported".to_string(),
87        ))
88    }
89
90    async fn request_stream(
91        &self,
92        _messages: &[ModelMessage],
93        _settings: Option<&ModelSettings>,
94        _params: &ModelRequestParameters,
95    ) -> Result<ModelStream, ModelError> {
96        Err(ModelError::Unsupported(
97            "streaming not supported".to_string(),
98        ))
99    }
100}
101
102#[derive(Debug, Error)]
103pub enum ModelError {
104    #[error("provider error: {0}")]
105    Provider(String),
106    #[error("http error status: {status}")]
107    HttpStatus { status: u16 },
108    #[error("transport error: {0}")]
109    Transport(String),
110    #[error("timeout error")]
111    Timeout,
112    #[error("unsupported: {0}")]
113    Unsupported(String),
114    #[error("serialization error: {0}")]
115    Serialization(#[from] serde_json::Error),
116}