1use std::pin::Pin;
2
3use async_trait::async_trait;
4use futures::stream::Stream;
5use serde_json::{Map, Value};
6use thiserror::Error;
7
8use crate::messages::{ModelMessage, ModelResponse, ToolCallPart};
9use crate::tools::ToolDefinition;
10use crate::usage::RequestUsage;
11
12pub type ModelSettings = Map<String, Value>;
13
/// Output mode carried by [`ModelRequestParameters`].
///
/// [`OutputMode::Text`] is the [`Default`] variant.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum OutputMode {
    /// Text output; the default mode.
    #[default]
    Text,
    /// JSON-schema output; this is the mode selected by
    /// `ModelRequestParameters::with_output_schema`.
    JsonSchema,
}
20
21#[derive(Clone, Debug)]
22pub struct ModelRequestParameters {
23 pub function_tools: Vec<ToolDefinition>,
24 pub output_schema: Option<Value>,
25 pub output_mode: OutputMode,
26 pub allow_text_output: bool,
27}
28
29impl ModelRequestParameters {
30 pub fn new(function_tools: Vec<ToolDefinition>) -> Self {
31 Self {
32 function_tools,
33 output_schema: None,
34 output_mode: OutputMode::Text,
35 allow_text_output: true,
36 }
37 }
38
39 pub fn with_output_schema(mut self, schema: Value) -> Self {
40 self.output_schema = Some(schema);
41 self.output_mode = OutputMode::JsonSchema;
42 self.allow_text_output = false;
43 self
44 }
45}
46
47impl Default for ModelRequestParameters {
48 fn default() -> Self {
49 Self {
50 function_tools: Vec::new(),
51 output_schema: None,
52 output_mode: OutputMode::Text,
53 allow_text_output: true,
54 }
55 }
56}
57
58#[derive(Clone, Debug)]
59pub struct StreamChunk {
60 pub text_delta: Option<String>,
61 pub tool_call: Option<ToolCallPart>,
62 pub finish_reason: Option<String>,
63 pub usage: Option<RequestUsage>,
64}
65
66pub type ModelStream = Pin<Box<dyn Stream<Item = Result<StreamChunk, ModelError>> + Send>>;
67
68#[async_trait]
69pub trait Model: Send + Sync {
70 fn name(&self) -> &str;
71
72 async fn request(
73 &self,
74 messages: &[ModelMessage],
75 settings: Option<&ModelSettings>,
76 params: &ModelRequestParameters,
77 ) -> Result<ModelResponse, ModelError>;
78
79 async fn count_tokens(
80 &self,
81 _messages: &[ModelMessage],
82 _settings: Option<&ModelSettings>,
83 _params: &ModelRequestParameters,
84 ) -> Result<RequestUsage, ModelError> {
85 Err(ModelError::Unsupported(
86 "token counting not supported".to_string(),
87 ))
88 }
89
90 async fn request_stream(
91 &self,
92 _messages: &[ModelMessage],
93 _settings: Option<&ModelSettings>,
94 _params: &ModelRequestParameters,
95 ) -> Result<ModelStream, ModelError> {
96 Err(ModelError::Unsupported(
97 "streaming not supported".to_string(),
98 ))
99 }
100}
101
102#[derive(Debug, Error)]
103pub enum ModelError {
104 #[error("provider error: {0}")]
105 Provider(String),
106 #[error("http error status: {status}")]
107 HttpStatus { status: u16 },
108 #[error("transport error: {0}")]
109 Transport(String),
110 #[error("timeout error")]
111 Timeout,
112 #[error("unsupported: {0}")]
113 Unsupported(String),
114 #[error("serialization error: {0}")]
115 Serialization(#[from] serde_json::Error),
116}