1use async_trait::async_trait;
4use schemars::JsonSchema;
5use serde::{Serialize, de::DeserializeOwned};
6use serde_json::Value;
7
8use crate::error::{LlmError, LlmResult};
9use crate::schema::generate_json_schema;
10use crate::types::{GenerationOptions, GenerationResponse, Message, MessageRole};
11
12#[async_trait]
17pub trait Llm: Send + Sync {
18 async fn generate(
20 &self,
21 messages: Vec<Message>,
22 options: Option<GenerationOptions>,
23 ) -> LlmResult<GenerationResponse>;
24
25 async fn create_structured_output_raw(
30 &self,
31 text_input: &str,
32 system_prompt: &str,
33 json_schema: &Value,
34 options: Option<GenerationOptions>,
35 ) -> LlmResult<Value> {
36 let messages = vec![
37 Message {
38 role: MessageRole::System,
39 content: system_prompt.to_string(),
40 },
41 Message {
42 role: MessageRole::User,
43 content: text_input.to_string(),
44 },
45 ];
46 self.create_structured_output_with_messages_raw(messages, json_schema, options)
47 .await
48 }
49
50 async fn create_structured_output_with_messages_raw(
55 &self,
56 messages: Vec<Message>,
57 json_schema: &Value,
58 options: Option<GenerationOptions>,
59 ) -> LlmResult<Value>;
60
61 fn model(&self) -> &str;
63
64 fn supports_streaming(&self) -> bool {
66 false
67 }
68
69 fn supports_function_calling(&self) -> bool {
71 false
72 }
73
74 fn max_context_length(&self) -> u32 {
76 4096
77 }
78
79 async fn transcribe_image(
92 &self,
93 image_bytes: &[u8],
94 mime_type: &str,
95 options: Option<GenerationOptions>,
96 ) -> LlmResult<String> {
97 let _ = (image_bytes, mime_type, options);
98 Err(LlmError::FeatureNotSupported(format!(
99 "Vision is not supported by model: {}",
100 self.model()
101 )))
102 }
103
104 fn supports_vision(&self) -> bool {
112 false
113 }
114}
115
116#[async_trait]
119pub trait LlmExt: Llm {
120 async fn create_structured_output<T>(
125 &self,
126 text_input: &str,
127 system_prompt: &str,
128 options: Option<GenerationOptions>,
129 ) -> LlmResult<T>
130 where
131 T: Serialize + DeserializeOwned + JsonSchema + Send,
132 {
133 let schema = generate_json_schema::<T>();
134 let value = self
135 .create_structured_output_raw(text_input, system_prompt, &schema, options)
136 .await?;
137 serde_json::from_value(value).map_err(|e| {
138 LlmError::DeserializationError(format!("Failed to deserialize structured output: {e}"))
139 })
140 }
141
142 async fn create_structured_output_with_messages<T>(
147 &self,
148 messages: Vec<Message>,
149 options: Option<GenerationOptions>,
150 ) -> LlmResult<T>
151 where
152 T: Serialize + DeserializeOwned + JsonSchema + Send,
153 {
154 let schema = generate_json_schema::<T>();
155 let value = self
156 .create_structured_output_with_messages_raw(messages, &schema, options)
157 .await?;
158 serde_json::from_value(value).map_err(|e| {
159 LlmError::DeserializationError(format!("Failed to deserialize structured output: {e}"))
160 })
161 }
162}
163
164impl<T: Llm + ?Sized> LlmExt for T {}
165
#[cfg(test)]
mod tests {
    #![allow(
        clippy::unwrap_used,
        clippy::expect_used,
        reason = "test code — panics are acceptable"
    )]
    use std::collections::BTreeMap;

    use super::*;

    /// Implements only the required methods, so every other call exercises a
    /// default provided by [`Llm`].
    struct DummyLlm;

    #[async_trait]
    impl Llm for DummyLlm {
        async fn generate(
            &self,
            _: Vec<Message>,
            _: Option<GenerationOptions>,
        ) -> LlmResult<GenerationResponse> {
            unimplemented!()
        }
        async fn create_structured_output_with_messages_raw(
            &self,
            _: Vec<Message>,
            _: &Value,
            _: Option<GenerationOptions>,
        ) -> LlmResult<Value> {
            unimplemented!()
        }
        fn model(&self) -> &str {
            "dummy"
        }
    }

    /// Echoes the conversation it receives as a JSON array of
    /// `{"role": ..., "content": ...}` objects, which makes the messages built by
    /// the default structured-output plumbing observable.
    struct EchoLlm;

    #[async_trait]
    impl Llm for EchoLlm {
        async fn generate(
            &self,
            _: Vec<Message>,
            _: Option<GenerationOptions>,
        ) -> LlmResult<GenerationResponse> {
            unimplemented!()
        }
        async fn create_structured_output_with_messages_raw(
            &self,
            messages: Vec<Message>,
            _: &Value,
            _: Option<GenerationOptions>,
        ) -> LlmResult<Value> {
            let transcript: Vec<Value> = messages
                .into_iter()
                .map(|message| {
                    // `matches!` rather than an exhaustive `match`, so this compiles
                    // regardless of which other variants `MessageRole` has.
                    let role = if matches!(message.role, MessageRole::System) {
                        "system"
                    } else if matches!(message.role, MessageRole::User) {
                        "user"
                    } else {
                        "other"
                    };
                    serde_json::json!({ "role": role, "content": message.content })
                })
                .collect();
            Ok(Value::Array(transcript))
        }
        fn model(&self) -> &str {
            "echo"
        }
    }

    /// Builds the `{"role", "content"}` map that [`EchoLlm`] emits for one message.
    fn echoed(role: &str, content: &str) -> BTreeMap<String, String> {
        BTreeMap::from([
            ("role".to_owned(), role.to_owned()),
            ("content".to_owned(), content.to_owned()),
        ])
    }

    #[tokio::test]
    async fn default_transcribe_image_returns_feature_not_supported() {
        let llm = DummyLlm;
        let err = llm
            .transcribe_image(b"fake-png", "image/png", None)
            .await
            .expect_err("default transcribe_image must fail");
        assert!(
            matches!(err, LlmError::FeatureNotSupported(_)),
            "Expected FeatureNotSupported, got: {err:?}"
        );
    }

    #[test]
    fn default_supports_vision_returns_false() {
        let llm = DummyLlm;
        assert!(!llm.supports_vision());
    }

    #[test]
    fn default_capabilities_are_conservative() {
        let llm = DummyLlm;
        assert!(!llm.supports_streaming());
        assert!(!llm.supports_function_calling());
        assert_eq!(llm.max_context_length(), 4096);
    }

    #[tokio::test]
    async fn default_structured_output_raw_sends_system_then_user() {
        let llm = EchoLlm;
        let value = llm
            .create_structured_output_raw("hello", "be brief", &Value::Null, None)
            .await
            .expect("echo backend never fails");
        assert_eq!(
            value,
            serde_json::json!([
                { "role": "system", "content": "be brief" },
                { "role": "user", "content": "hello" }
            ])
        );
    }

    #[tokio::test]
    async fn structured_output_deserializes_into_requested_type() {
        let llm = EchoLlm;
        let transcript: Vec<BTreeMap<String, String>> = llm
            .create_structured_output("hello", "be brief", None)
            .await
            .expect("echoed transcript should deserialize");
        assert_eq!(
            transcript,
            vec![echoed("system", "be brief"), echoed("user", "hello")]
        );
    }

    #[tokio::test]
    async fn structured_output_with_messages_passes_messages_through() {
        let llm = EchoLlm;
        let messages = vec![Message {
            role: MessageRole::User,
            content: "only me".to_owned(),
        }];
        let transcript: Vec<BTreeMap<String, String>> = llm
            .create_structured_output_with_messages(messages, None)
            .await
            .expect("echoed transcript should deserialize");
        assert_eq!(transcript, vec![echoed("user", "only me")]);
    }

    #[tokio::test]
    async fn structured_output_type_mismatch_is_deserialization_error() {
        let llm = EchoLlm;
        let result: LlmResult<u32> = llm
            .create_structured_output("hello", "be brief", None)
            .await;
        let err = result.expect_err("a JSON array cannot deserialize into u32");
        assert!(
            matches!(err, LlmError::DeserializationError(_)),
            "Expected DeserializationError, got: {err:?}"
        );
    }
}