1use async_trait::async_trait;
4
5use crate::errors::LLMError;
6use crate::models::Message;
7
/// Per-request settings handed to [`LLM::generate`].
///
/// The derived `Default` leaves `temperature` and `max_tokens` as `None` and turns
/// `json_mode` off.
#[derive(Debug, Clone, Default)]
pub struct GenerateOptions {
    /// Sampling temperature. `None` means no value is specified; how a backend treats
    /// that is up to the implementation.
    pub temperature: Option<f32>,

    /// Upper bound on generated tokens. `None` means no explicit limit is requested.
    pub max_tokens: Option<u32>,

    /// Requests JSON output from the backend. [`generate_json`] always forces this to `true`.
    pub json_mode: bool,
}
20
21#[async_trait]
23pub trait LLM: Send + Sync {
24 async fn generate(
26 &self,
27 messages: &[Message],
28 options: GenerateOptions,
29 ) -> Result<String, LLMError>;
30
31 fn model_name(&self) -> &str;
33}
34
35pub async fn generate_json<T: serde::de::DeserializeOwned>(
37 llm: &dyn LLM,
38 messages: &[Message],
39 options: GenerateOptions,
40) -> Result<T, LLMError> {
41 let mut opts = options;
42 opts.json_mode = true;
43
44 let response = llm.generate(messages, opts).await?;
45
46 let json_str = extract_json(&response);
48
49 serde_json::from_str(&json_str)
50 .map_err(|e| LLMError::JsonParse(format!("{}: {}", e, json_str)))
51}
52
/// Extracts the JSON payload from a free-form model response.
///
/// The response is trimmed first. The strategies below are then tried in order, and the
/// first one that matches wins:
///
/// 1. The trimmed body of the first fenced block opened with three backticks followed
///    by `json`.
/// 2. The trimmed body of the first fenced block of any kind. If that body spans several
///    lines and its first line does not start with `{` or `[`, the first line is treated
///    as a language tag and dropped.
/// 3. The span from the first `{` to the last `}`. If the span from the first `[` to the
///    last `]` strictly encloses it, the array span is used instead, so a top-level array
///    of objects stays intact.
/// 4. The span from the first `[` to the last `]`.
/// 5. The trimmed response, unchanged.
///
/// In steps 3–4 a span is skipped when its last closing delimiter comes before its first
/// opening delimiter. The function never slices an inverted range and never panics.
///
/// The result is only a best-effort candidate; it is not validated as JSON here.
fn extract_json(response: &str) -> String {
    /// Byte offsets of the first `open` and the last `close` in `text`, if the closer
    /// comes after the opener.
    fn delimited_span(text: &str, open: char, close: char) -> Option<(usize, usize)> {
        let start = text.find(open)?;
        let end = text.rfind(close)?;
        (end > start).then_some((start, end))
    }

    const JSON_FENCE: &str = "```json";
    const FENCE: &str = "```";

    let response = response.trim();

    // 1. Explicitly tagged JSON fence.
    if let Some(start) = response.find(JSON_FENCE) {
        let after = &response[start + JSON_FENCE.len()..];
        if let Some(end) = after.find(FENCE) {
            return after[..end].trim().to_string();
        }
    }

    // 2. Any fence. Drop a leading language-tag line when the body doesn't start with JSON.
    if let Some(start) = response.find(FENCE) {
        let after = &response[start + FENCE.len()..];
        if let Some(end) = after.find(FENCE) {
            let content = after[..end].trim();
            if let Some((first_line, rest)) = content.split_once('\n') {
                if !first_line.starts_with('{') && !first_line.starts_with('[') {
                    return rest.trim().to_string();
                }
            }
            return content.to_string();
        }
    }

    // 3/4. A bare object or array embedded in prose. Objects win unless an array
    // strictly wraps the object span (e.g. `[{..}, {..}]`). Taking the object span in
    // that case would yield invalid JSON.
    let object = delimited_span(response, '{', '}');
    let array = delimited_span(response, '[', ']');
    let chosen = match (object, array) {
        (Some(obj), Some(arr)) if arr.0 < obj.0 && arr.1 > obj.1 => Some(arr),
        (Some(obj), _) => Some(obj),
        (None, arr) => arr,
    };

    match chosen {
        // Both offsets point at single-byte ASCII delimiters, so `..=end` ends on a
        // char boundary.
        Some((start, end)) => response[start..=end].to_string(),
        None => response.to_string(),
    }
}
94
#[cfg(test)]
mod tests {
    use super::*;

    /// A bare JSON object comes back untouched.
    #[test]
    fn test_extract_json_raw() {
        let raw = r#"{"key": "value"}"#;
        let extracted = extract_json(raw);
        assert_eq!(extracted, raw);
    }

    /// A fence tagged as JSON yields only its trimmed body.
    #[test]
    fn test_extract_json_code_block() {
        let fenced = "```json\n{\"key\": \"value\"}\n```";
        let expected = "{\"key\": \"value\"}";
        assert_eq!(extract_json(fenced), expected);
    }

    /// Prose on either side of the object is discarded.
    #[test]
    fn test_extract_json_with_text() {
        let wrapped = r#"Here is the result: {"key": "value"} as requested."#;
        let expected = r#"{"key": "value"}"#;
        assert_eq!(extract_json(wrapped), expected);
    }
}
118}