use async_trait::async_trait;
use crate::errors::LLMError;
use crate::models::Message;
#[derive(Debug, Clone, Default)]
pub struct GenerateOptions {
pub temperature: Option<f32>,
pub max_tokens: Option<u32>,
pub json_mode: bool,
}
#[async_trait]
pub trait LLM: Send + Sync {
async fn generate(
&self,
messages: &[Message],
options: GenerateOptions,
) -> Result<String, LLMError>;
fn model_name(&self) -> &str;
}
pub async fn generate_json<T: serde::de::DeserializeOwned>(
llm: &dyn LLM,
messages: &[Message],
options: GenerateOptions,
) -> Result<T, LLMError> {
let mut opts = options;
opts.json_mode = true;
let response = llm.generate(messages, opts).await?;
let json_str = extract_json(&response);
serde_json::from_str(&json_str)
.map_err(|e| LLMError::JsonParse(format!("{}: {}", e, json_str)))
}
fn extract_json(response: &str) -> String {
let response = response.trim();
if let Some(start) = response.find("```json") {
if let Some(end) = response[start + 7..].find("```") {
return response[start + 7..start + 7 + end].trim().to_string();
}
}
if let Some(start) = response.find("```") {
if let Some(end) = response[start + 3..].find("```") {
let content = response[start + 3..start + 3 + end].trim();
if let Some(newline) = content.find('\n') {
let first_line = &content[..newline];
if !first_line.starts_with('{') && !first_line.starts_with('[') {
return content[newline + 1..].trim().to_string();
}
}
return content.to_string();
}
}
if let Some(start) = response.find('{') {
if let Some(end) = response.rfind('}') {
return response[start..=end].to_string();
}
}
if let Some(start) = response.find('[') {
if let Some(end) = response.rfind(']') {
return response[start..=end].to_string();
}
}
response.to_string()
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_extract_json_raw() {
let input = r#"{"key": "value"}"#;
assert_eq!(extract_json(input), r#"{"key": "value"}"#);
}
#[test]
fn test_extract_json_code_block() {
let input = r#"```json
{"key": "value"}
```"#;
assert_eq!(extract_json(input), r#"{"key": "value"}"#);
}
#[test]
fn test_extract_json_with_text() {
let input = r#"Here is the result: {"key": "value"} as requested."#;
assert_eq!(extract_json(input), r#"{"key": "value"}"#);
}
}