// cognee_llm/llm_trait.rs
//! LLM trait definition for structured output generation.

3use async_trait::async_trait;
4use schemars::JsonSchema;
5use serde::{Serialize, de::DeserializeOwned};
6use serde_json::Value;
7
8use crate::error::{LlmError, LlmResult};
9use crate::schema::generate_json_schema;
10use crate::types::{GenerationOptions, GenerationResponse, Message, MessageRole};
11
12/// Object-safe base trait for LLM implementations.
13///
14/// Provides type-erased methods that work with `serde_json::Value` for JSON schemas
15/// and responses. For ergonomic generic methods, see [`LlmExt`].
16#[async_trait]
17pub trait Llm: Send + Sync {
18    /// Generate text completion from messages.
19    async fn generate(
20        &self,
21        messages: Vec<Message>,
22        options: Option<GenerationOptions>,
23    ) -> LlmResult<GenerationResponse>;
24
25    /// Generate structured output from text (type-erased).
26    ///
27    /// Takes a pre-built JSON schema and returns the raw JSON `Value`.
28    /// Prefer using [`LlmExt::create_structured_output`] for typed access.
29    async fn create_structured_output_raw(
30        &self,
31        text_input: &str,
32        system_prompt: &str,
33        json_schema: &Value,
34        options: Option<GenerationOptions>,
35    ) -> LlmResult<Value> {
36        let messages = vec![
37            Message {
38                role: MessageRole::System,
39                content: system_prompt.to_string(),
40            },
41            Message {
42                role: MessageRole::User,
43                content: text_input.to_string(),
44            },
45        ];
46        self.create_structured_output_with_messages_raw(messages, json_schema, options)
47            .await
48    }
49
50    /// Generate structured output from messages (type-erased).
51    ///
52    /// Takes a pre-built JSON schema and returns the raw JSON `Value`.
53    /// Prefer using [`LlmExt::create_structured_output_with_messages`] for typed access.
54    async fn create_structured_output_with_messages_raw(
55        &self,
56        messages: Vec<Message>,
57        json_schema: &Value,
58        options: Option<GenerationOptions>,
59    ) -> LlmResult<Value>;
60
61    /// Get the model identifier.
62    fn model(&self) -> &str;
63
64    /// Check if the LLM supports streaming.
65    fn supports_streaming(&self) -> bool {
66        false
67    }
68
69    /// Check if the LLM supports function calling / tool use.
70    fn supports_function_calling(&self) -> bool {
71        false
72    }
73
74    /// Get the maximum context length (in tokens) for this model.
75    fn max_context_length(&self) -> u32 {
76        4096
77    }
78
79    /// Describe the contents of an image using vision capabilities.
80    ///
81    /// Returns a text description of the image. The default implementation
82    /// returns `LlmError::FeatureNotSupported` — override in adapters that
83    /// support vision (e.g. OpenAI GPT-4o, Ollama llava).
84    ///
85    /// # Arguments
86    /// * `image_bytes` — Raw image bytes (PNG, JPEG, WebP, GIF, etc.)
87    /// * `mime_type` — MIME type string (must start with `"image/"`)
88    /// * `options` — Optional generation parameters; if `None`, the
89    ///   implementation should use hardcoded defaults matching the Python SDK
90    ///   (max_tokens=300).
91    async fn transcribe_image(
92        &self,
93        image_bytes: &[u8],
94        mime_type: &str,
95        options: Option<GenerationOptions>,
96    ) -> LlmResult<String> {
97        let _ = (image_bytes, mime_type, options);
98        Err(LlmError::FeatureNotSupported(format!(
99            "Vision is not supported by model: {}",
100            self.model()
101        )))
102    }
103
104    /// Whether this adapter supports image transcription.
105    ///
106    /// This is a best-effort heuristic based on the model name. A `true`
107    /// return does not guarantee the API will accept vision requests; a
108    /// `false` return does not prevent calling `transcribe_image` (which
109    /// will return `FeatureNotSupported` from the default impl, or attempt
110    /// the API call and surface a server error from a real adapter).
111    fn supports_vision(&self) -> bool {
112        false
113    }
114}
115
116/// Extension trait providing generic convenience methods on top of [`Llm`].
117/// Auto-implemented for all types that implement `Llm`.
118#[async_trait]
119pub trait LlmExt: Llm {
120    /// Generate structured output from text input.
121    ///
122    /// Generates a JSON schema from `T`, calls the type-erased
123    /// [`Llm::create_structured_output_raw`], and deserializes the result.
124    async fn create_structured_output<T>(
125        &self,
126        text_input: &str,
127        system_prompt: &str,
128        options: Option<GenerationOptions>,
129    ) -> LlmResult<T>
130    where
131        T: Serialize + DeserializeOwned + JsonSchema + Send,
132    {
133        let schema = generate_json_schema::<T>();
134        let value = self
135            .create_structured_output_raw(text_input, system_prompt, &schema, options)
136            .await?;
137        serde_json::from_value(value).map_err(|e| {
138            LlmError::DeserializationError(format!("Failed to deserialize structured output: {e}"))
139        })
140    }
141
142    /// Generate structured output from custom messages.
143    ///
144    /// Generates a JSON schema from `T`, calls the type-erased
145    /// [`Llm::create_structured_output_with_messages_raw`], and deserializes the result.
146    async fn create_structured_output_with_messages<T>(
147        &self,
148        messages: Vec<Message>,
149        options: Option<GenerationOptions>,
150    ) -> LlmResult<T>
151    where
152        T: Serialize + DeserializeOwned + JsonSchema + Send,
153    {
154        let schema = generate_json_schema::<T>();
155        let value = self
156            .create_structured_output_with_messages_raw(messages, &schema, options)
157            .await?;
158        serde_json::from_value(value).map_err(|e| {
159            LlmError::DeserializationError(format!("Failed to deserialize structured output: {e}"))
160        })
161    }
162}
163
164impl<T: Llm + ?Sized> LlmExt for T {}
165
#[cfg(test)]
mod tests {
    #![allow(
        clippy::unwrap_used,
        clippy::expect_used,
        reason = "test code — panics are acceptable"
    )]
    use std::sync::Mutex;

    use super::*;

    /// Adapter that implements only the required methods, so every provided
    /// method runs with its default implementation.
    struct DummyLlm;

    #[async_trait]
    impl Llm for DummyLlm {
        async fn generate(
            &self,
            _: Vec<Message>,
            _: Option<GenerationOptions>,
        ) -> LlmResult<GenerationResponse> {
            unimplemented!()
        }
        async fn create_structured_output_with_messages_raw(
            &self,
            _: Vec<Message>,
            _: &Value,
            _: Option<GenerationOptions>,
        ) -> LlmResult<Value> {
            unimplemented!()
        }
        fn model(&self) -> &str {
            "dummy"
        }
    }

    /// Adapter that records the conversation passed to
    /// `create_structured_output_with_messages_raw` and replies with a canned value.
    struct RecordingLlm {
        reply: Value,
        seen: Mutex<Option<Vec<Message>>>,
    }

    impl RecordingLlm {
        fn replying(reply: Value) -> Self {
            Self {
                reply,
                seen: Mutex::new(None),
            }
        }
    }

    #[async_trait]
    impl Llm for RecordingLlm {
        async fn generate(
            &self,
            _: Vec<Message>,
            _: Option<GenerationOptions>,
        ) -> LlmResult<GenerationResponse> {
            unimplemented!()
        }
        async fn create_structured_output_with_messages_raw(
            &self,
            messages: Vec<Message>,
            _: &Value,
            _: Option<GenerationOptions>,
        ) -> LlmResult<Value> {
            // The guard is a statement temporary, released before the future ends.
            *self.seen.lock().unwrap() = Some(messages);
            Ok(self.reply.clone())
        }
        fn model(&self) -> &str {
            "recording"
        }
    }

    #[tokio::test]
    async fn default_transcribe_image_returns_feature_not_supported() {
        let err = DummyLlm
            .transcribe_image(b"fake-png", "image/png", None)
            .await
            .expect_err("default transcribe_image must fail");
        assert!(
            matches!(&err, LlmError::FeatureNotSupported(msg) if msg.contains("dummy")),
            "Expected FeatureNotSupported naming the model, got: {err:?}"
        );
    }

    #[test]
    fn default_supports_vision_returns_false() {
        assert!(!DummyLlm.supports_vision());
    }

    #[test]
    fn default_capability_flags() {
        let llm = DummyLlm;
        assert!(!llm.supports_streaming());
        assert!(!llm.supports_function_calling());
        assert_eq!(llm.max_context_length(), 4096);
    }

    #[tokio::test]
    async fn structured_output_raw_sends_system_then_user_message() {
        let llm = RecordingLlm::replying(Value::Null);
        let out = llm
            .create_structured_output_raw("user text", "system text", &Value::Null, None)
            .await
            .unwrap();
        assert_eq!(out, Value::Null);

        let seen = llm
            .seen
            .lock()
            .unwrap()
            .take()
            .expect("messages were not forwarded");
        assert_eq!(seen.len(), 2);
        assert!(matches!(seen[0].role, MessageRole::System));
        assert_eq!(seen[0].content, "system text");
        assert!(matches!(seen[1].role, MessageRole::User));
        assert_eq!(seen[1].content, "user text");
    }

    #[tokio::test]
    async fn create_structured_output_deserializes_reply() {
        let llm = RecordingLlm::replying(Value::String("hello".to_owned()));
        let out: String = llm
            .create_structured_output("input", "system", None)
            .await
            .unwrap();
        assert_eq!(out, "hello");
    }

    #[tokio::test]
    async fn create_structured_output_reports_type_mismatch() {
        let llm = RecordingLlm::replying(Value::String("not a number".to_owned()));
        let result: LlmResult<u32> = llm.create_structured_output("input", "system", None).await;
        let err = result.expect_err("a JSON string must not deserialize into u32");
        assert!(
            matches!(err, LlmError::DeserializationError(_)),
            "Expected DeserializationError, got: {err:?}"
        );
    }
}