use async_trait::async_trait;
use schemars::JsonSchema;
use serde::{Serialize, de::DeserializeOwned};
use serde_json::Value;
use crate::error::{LlmError, LlmResult};
use crate::schema::generate_json_schema;
use crate::types::{GenerationOptions, GenerationResponse, Message, MessageRole};
/// Common interface implemented by every language-model backend.
///
/// Implementors must provide [`Llm::generate`],
/// [`Llm::create_structured_output_with_messages_raw`] and [`Llm::model`].
/// All remaining methods ship with defaults that can be overridden.
#[async_trait]
pub trait Llm: Send + Sync {
    /// Produces a [`GenerationResponse`] for `messages`, optionally tuned by `options`.
    ///
    /// # Errors
    ///
    /// Returns whatever [`LlmError`] the implementor reports.
    async fn generate(
        &self,
        messages: Vec<Message>,
        options: Option<GenerationOptions>,
    ) -> LlmResult<GenerationResponse>;

    /// Requests schema-guided JSON output from a system prompt plus a single user input.
    ///
    /// The default builds a two-message conversation (system message first, user
    /// message second) and forwards it, together with `json_schema` and `options`,
    /// to [`Llm::create_structured_output_with_messages_raw`].
    ///
    /// # Errors
    ///
    /// Propagates any error from [`Llm::create_structured_output_with_messages_raw`].
    async fn create_structured_output_raw(
        &self,
        text_input: &str,
        system_prompt: &str,
        json_schema: &Value,
        options: Option<GenerationOptions>,
    ) -> LlmResult<Value> {
        let system = Message {
            role: MessageRole::System,
            content: system_prompt.to_owned(),
        };
        let user = Message {
            role: MessageRole::User,
            content: text_input.to_owned(),
        };
        self.create_structured_output_with_messages_raw(vec![system, user], json_schema, options)
            .await
    }

    /// Requests JSON output for an arbitrary conversation, given `json_schema`.
    ///
    /// NOTE(review): the returned value is presumably expected to conform to
    /// `json_schema`; whether that is enforced is up to each implementor.
    ///
    /// # Errors
    ///
    /// Returns whatever [`LlmError`] the implementor reports.
    async fn create_structured_output_with_messages_raw(
        &self,
        messages: Vec<Message>,
        json_schema: &Value,
        options: Option<GenerationOptions>,
    ) -> LlmResult<Value>;

    /// Identifier of the underlying model; also used in the default
    /// [`Llm::transcribe_image`] error message.
    fn model(&self) -> &str;

    /// Whether streaming is supported. Defaults to `false`.
    fn supports_streaming(&self) -> bool {
        false
    }

    /// Whether function calling is supported. Defaults to `false`.
    fn supports_function_calling(&self) -> bool {
        false
    }

    /// Maximum context length. Defaults to `4096` (presumably tokens — confirm with implementors).
    fn max_context_length(&self) -> u32 {
        4096
    }

    /// Transcribes the content of an image.
    ///
    /// The default implementation ignores all arguments and always fails.
    ///
    /// # Errors
    ///
    /// The default returns [`LlmError::FeatureNotSupported`] naming [`Llm::model`].
    async fn transcribe_image(
        &self,
        image_bytes: &[u8],
        mime_type: &str,
        options: Option<GenerationOptions>,
    ) -> LlmResult<String> {
        // The default has no use for its inputs; consume them so they are not flagged as unused.
        let _ = (image_bytes, mime_type, options);
        let reason = format!("Vision is not supported by model: {}", self.model());
        Err(LlmError::FeatureNotSupported(reason))
    }

    /// Whether image input is supported. Defaults to `false`.
    fn supports_vision(&self) -> bool {
        false
    }
}
#[async_trait]
pub trait LlmExt: Llm {
async fn create_structured_output<T>(
&self,
text_input: &str,
system_prompt: &str,
options: Option<GenerationOptions>,
) -> LlmResult<T>
where
T: Serialize + DeserializeOwned + JsonSchema + Send,
{
let schema = generate_json_schema::<T>();
let value = self
.create_structured_output_raw(text_input, system_prompt, &schema, options)
.await?;
serde_json::from_value(value).map_err(|e| {
LlmError::DeserializationError(format!("Failed to deserialize structured output: {e}"))
})
}
async fn create_structured_output_with_messages<T>(
&self,
messages: Vec<Message>,
options: Option<GenerationOptions>,
) -> LlmResult<T>
where
T: Serialize + DeserializeOwned + JsonSchema + Send,
{
let schema = generate_json_schema::<T>();
let value = self
.create_structured_output_with_messages_raw(messages, &schema, options)
.await?;
serde_json::from_value(value).map_err(|e| {
LlmError::DeserializationError(format!("Failed to deserialize structured output: {e}"))
})
}
}
/// Blanket implementation: every [`Llm`], including unsized ones such as
/// `dyn Llm`, gets the [`LlmExt`] methods.
impl<T> LlmExt for T where T: Llm + ?Sized {}
#[cfg(test)]
mod tests {
    #![allow(
        clippy::unwrap_used,
        clippy::expect_used,
        reason = "test code — panics are acceptable"
    )]

    use super::*;

    // Bare-bones backend: only the required methods are implemented, so every
    // default method on `Llm` is exercised unchanged.
    struct DummyLlm;

    #[async_trait]
    impl Llm for DummyLlm {
        async fn generate(
            &self,
            _messages: Vec<Message>,
            _options: Option<GenerationOptions>,
        ) -> LlmResult<GenerationResponse> {
            unimplemented!()
        }

        async fn create_structured_output_with_messages_raw(
            &self,
            _messages: Vec<Message>,
            _json_schema: &Value,
            _options: Option<GenerationOptions>,
        ) -> LlmResult<Value> {
            unimplemented!()
        }

        fn model(&self) -> &str {
            "dummy"
        }
    }

    // The default `transcribe_image` must reject the call with `FeatureNotSupported`.
    #[tokio::test]
    async fn default_transcribe_image_returns_feature_not_supported() {
        let llm = DummyLlm;
        let result = llm.transcribe_image(b"fake-png", "image/png", None).await;

        assert!(result.is_err());

        let err = result.unwrap_err();
        assert!(
            matches!(err, LlmError::FeatureNotSupported(_)),
            "Expected FeatureNotSupported, got: {err:?}"
        );
    }

    // Backends are assumed not to handle images unless they opt in.
    #[test]
    fn default_supports_vision_returns_false() {
        let llm = DummyLlm;

        assert!(!llm.supports_vision());
    }
}