mem0_rust/llms/
traits.rs

1//! LLM trait definition.
2
3use async_trait::async_trait;
4
5use crate::errors::LLMError;
6use crate::models::Message;
7
/// Per-request settings handed to [`LLM::generate`].
///
/// `Default` yields `None` for both numeric settings and `json_mode = false`.
/// What `None` means is up to each provider (presumably "use the backend's
/// own default" — NOTE(review): confirm per backend).
#[derive(Clone, Debug, Default)]
pub struct GenerateOptions {
    /// Sampling temperature. Originally documented as 0.0 to 1.0; the range
    /// actually accepted is provider-specific — NOTE(review): confirm.
    pub temperature: Option<f32>,

    /// Upper bound on the number of tokens the model may produce.
    pub max_tokens: Option<u32>,

    /// Request JSON-only output. `generate_json` always sets this to `true`.
    pub json_mode: bool,
}
20
21/// Trait for LLM providers
22#[async_trait]
23pub trait LLM: Send + Sync {
24    /// Generate a text response
25    async fn generate(
26        &self,
27        messages: &[Message],
28        options: GenerateOptions,
29    ) -> Result<String, LLMError>;
30
31    /// Get the model name
32    fn model_name(&self) -> &str;
33}
34
35/// Generate and parse a JSON response (standalone function)
36pub async fn generate_json<T: serde::de::DeserializeOwned>(
37    llm: &dyn LLM,
38    messages: &[Message],
39    options: GenerateOptions,
40) -> Result<T, LLMError> {
41    let mut opts = options;
42    opts.json_mode = true;
43
44    let response = llm.generate(messages, opts).await?;
45
46    // Try to extract JSON from response (handle markdown code blocks)
47    let json_str = extract_json(&response);
48
49    serde_json::from_str(&json_str)
50        .map_err(|e| LLMError::JsonParse(format!("{}: {}", e, json_str)))
51}
52
/// Extract the JSON payload from a raw LLM response.
///
/// The response is trimmed, then these strategies are tried in order:
/// 1. The body of the first fenced block opened with "```json".
/// 2. The body of the first generic "```" fenced block. A leading
///    language-identifier line (one that does not start with `{` or `[`)
///    is dropped.
/// 3. A raw JSON object or array embedded in surrounding prose
///    (see `extract_raw_json`).
///
/// If nothing matches, the trimmed response is returned unchanged, so the
/// caller's JSON parser reports the failure instead of this function
/// panicking.
fn extract_json(response: &str) -> String {
    let response = response.trim();

    extract_fenced_json(response)
        .or_else(|| extract_raw_json(response))
        .unwrap_or(response)
        .to_string()
}

/// Return the trimmed contents of the first Markdown code fence, if any.
fn extract_fenced_json(text: &str) -> Option<&str> {
    const JSON_FENCE: &str = "```json";
    const FENCE: &str = "```";

    // Prefer an explicitly tagged ```json block when it is properly closed.
    if let Some(start) = text.find(JSON_FENCE) {
        let body = &text[start + JSON_FENCE.len()..];
        if let Some(end) = body.find(FENCE) {
            return Some(body[..end].trim());
        }
    }

    // Fall back to the first closed generic fence.
    let start = text.find(FENCE)?;
    let body = &text[start + FENCE.len()..];
    let end = body.find(FENCE)?;
    let content = body[..end].trim();

    // Drop a language identifier line such as `JSON` or `javascript`.
    if let Some((first_line, rest)) = content.split_once('\n') {
        if !first_line.starts_with('{') && !first_line.starts_with('[') {
            return Some(rest.trim());
        }
    }
    Some(content)
}

/// Locate a JSON object or array embedded in surrounding prose.
///
/// The object candidate spans from the first `{` to the last `}`. The array
/// candidate spans from the first `[` to the last `]`. A candidate whose
/// closing delimiter does not come after its opener is discarded; slicing it
/// would panic.
///
/// The object is preferred, except when the array strictly encloses it. That
/// case is a top-level array of objects such as `[{"a":1},{"b":2}]`, where
/// taking the object span would yield the invalid `{"a":1},{"b":2}`.
fn extract_raw_json(text: &str) -> Option<&str> {
    let span = |open: char, close: char| -> Option<(usize, usize)> {
        let start = text.find(open)?;
        let end = text.rfind(close)?;
        if end > start {
            Some((start, end))
        } else {
            None
        }
    };

    let object = span('{', '}');
    let array = span('[', ']');

    let (start, end) = match (object, array) {
        (Some(obj), Some(arr)) if arr.0 < obj.0 && arr.1 > obj.1 => arr,
        (Some(obj), _) => obj,
        (None, Some(arr)) => arr,
        (None, None) => return None,
    };

    // Both delimiters are single-byte ASCII, so these are valid char boundaries.
    Some(&text[start..=end])
}
94
/// Unit tests for the response-to-JSON extraction helper.
#[cfg(test)]
mod tests {
    use super::extract_json;

    /// A bare JSON object passes through untouched.
    #[test]
    fn test_extract_json_raw() {
        let expected = "{\"key\": \"value\"}";
        let actual = extract_json(expected);
        assert_eq!(actual, expected);
    }

    /// A ```json fenced block is unwrapped.
    #[test]
    fn test_extract_json_code_block() {
        let fenced = "```json\n{\"key\": \"value\"}\n```";
        let expected = "{\"key\": \"value\"}";
        assert_eq!(extract_json(fenced), expected);
    }

    /// An object embedded in prose is cut out of the surrounding text.
    #[test]
    fn test_extract_json_with_text() {
        let wrapped = "Here is the result: {\"key\": \"value\"} as requested.";
        let expected = "{\"key\": \"value\"}";
        assert_eq!(extract_json(wrapped), expected);
    }
}