use async_trait::async_trait;
use futures::stream::Stream;
use serde::{Deserialize, Serialize};
use std::pin::Pin;
use crate::messaging::AgentMessage;
use crate::tools::ToolSchema;
/// Input handed to a [`LanguageModel`] for a single generation call.
///
/// Derives `Serialize`/`Deserialize`, so the field names below are also the
/// serialized keys.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmRequest {
/// System prompt text sent along with the messages.
pub system_prompt: String,
/// Messages included in the request, in the order supplied by the caller.
pub messages: Vec<AgentMessage>,
/// Tool schemas attached to the request.
///
/// `#[serde(default)]` lets this key be omitted when deserializing, in which
/// case the list is empty.
#[serde(default)]
pub tools: Vec<ToolSchema>,
}
impl LlmRequest {
    /// Builds a request from a system prompt and a list of messages.
    ///
    /// `system_prompt` accepts any value convertible into a `String`. The
    /// resulting request carries an empty tool list; chain
    /// [`with_tools`](Self::with_tools) to attach tools.
    pub fn new(system_prompt: impl Into<String>, messages: Vec<AgentMessage>) -> Self {
        let system_prompt = system_prompt.into();
        let tools = Vec::new();
        Self {
            system_prompt,
            messages,
            tools,
        }
    }

    /// Consumes the request and returns it with `tools` as its tool list.
    ///
    /// Any tool list previously stored on the request is replaced, not merged.
    pub fn with_tools(self, tools: Vec<ToolSchema>) -> Self {
        Self { tools, ..self }
    }
}
/// Result of a completed, non-streaming [`LanguageModel::generate`] call.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmResponse {
/// The message produced for the request.
pub message: AgentMessage,
}
/// One item yielded by a [`ChunkStream`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum StreamChunk {
/// A piece of text output.
/// NOTE(review): presumably an incremental fragment to be appended to the
/// text received so far; confirm against the implementations that emit it.
TextDelta(String),
/// Carries the complete message. The default
/// `LanguageModel::generate_stream` implementation yields exactly one of
/// these and nothing else.
Done {
/// The full message for the request.
message: AgentMessage,
},
/// An error description delivered as a stream item.
/// NOTE(review): how this relates to an `Err` item of the stream is not
/// defined in this file; confirm with the implementations.
Error(String),
}
/// Heap-allocated, pinned, `Send` stream whose items are
/// `anyhow::Result<StreamChunk>`; the return type of
/// `LanguageModel::generate_stream`.
pub type ChunkStream = Pin<Box<dyn Stream<Item = anyhow::Result<StreamChunk>> + Send>>;
#[async_trait]
pub trait LanguageModel: Send + Sync {
async fn generate(&self, request: LlmRequest) -> anyhow::Result<LlmResponse>;
async fn generate_stream(&self, request: LlmRequest) -> anyhow::Result<ChunkStream> {
let response = self.generate(request).await?;
Ok(Box::pin(futures::stream::once(async move {
Ok(StreamChunk::Done {
message: response.message,
})
})))
}
}