use async_trait::async_trait;
use futures::stream::BoxStream;
use serde::{Deserialize, Serialize};
use crate::error::Result;
use crate::message::{Message, ToolCall};
use crate::tool::ToolSchema;
/// Token counters attached to an LLM response.
///
/// Every counter is optional. [`Usage::merge`] skips counters that are `None`
/// in the incoming value, so a partially filled `Usage` is a valid state.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct Usage {
    /// Tokens attributed to the prompt, when known.
    pub prompt_tokens: Option<u32>,
    /// Tokens attributed to the completion, when known.
    pub completion_tokens: Option<u32>,
    /// Overall token count, when known.
    ///
    /// NOTE(review): presumably `prompt_tokens + completion_tokens`, but nothing
    /// in this type enforces that relationship — confirm against the providers.
    pub total_tokens: Option<u32>,
}
impl Usage {
    /// Adds the counters reported by `other` onto `self`, field by field.
    ///
    /// For each of `prompt_tokens`, `completion_tokens` and `total_tokens`:
    ///
    /// * if the counter is `None` in `other`, the counter in `self` is left
    ///   untouched (it stays `None` if it was `None`);
    /// * if the counter is `Some` in `other`, it is added to the counter in
    ///   `self`, treating a missing value in `self` as zero.
    ///
    /// Sums saturate at `u32::MAX`. A plain `+=` would panic in debug builds and
    /// silently wrap around in release builds once the running total overflowed.
    pub fn merge(&mut self, other: &Usage) {
        /// Folds one optional counter into its running total.
        fn accumulate(slot: &mut Option<u32>, extra: Option<u32>) {
            if let Some(n) = extra {
                *slot = Some(slot.unwrap_or(0).saturating_add(n));
            }
        }

        accumulate(&mut self.prompt_tokens, other.prompt_tokens);
        accumulate(&mut self.completion_tokens, other.completion_tokens);
        accumulate(&mut self.total_tokens, other.total_tokens);
    }
}
/// Input handed to [`LlmClient::complete`] and [`LlmClient::stream`].
#[derive(Debug, Clone)]
pub struct LlmRequest {
    /// Name of the model the request targets.
    pub model: String,
    /// Messages sent with the request, in the order stored here.
    pub messages: Vec<Message>,
    /// Schemas of the tools offered with the request; may be empty.
    pub tools: Vec<ToolSchema>,
    /// Optional temperature setting.
    ///
    /// NOTE(review): presumably `None` means "use the provider default" — the
    /// handling lives in the client implementations, confirm there.
    pub temperature: Option<f32>,
    /// Optional token limit.
    ///
    /// NOTE(review): presumably caps the generated tokens; how `None` is
    /// treated is up to the client implementation — confirm there.
    pub max_tokens: Option<u32>,
}
/// Aggregated result of a completion.
///
/// The `Default` value has no content, no tool calls, no finish reason and no
/// usage; [`LlmClient::complete`] starts from it and fills fields in as stream
/// events arrive.
#[derive(Debug, Clone, Default)]
pub struct LlmResponse {
    /// Generated text, or `None` when no text was produced.
    pub content: Option<String>,
    /// Tool calls carried by the response; empty when there are none.
    pub tool_calls: Vec<ToolCall>,
    /// Reason the generation finished, when one was reported.
    pub finish_reason: Option<String>,
    /// Token usage for the request, when one was reported.
    pub usage: Option<Usage>,
}
/// One item yielded by the stream that [`LlmClient::stream`] returns.
///
/// The per-variant notes describe how the default [`LlmClient::complete`]
/// implementation consumes each event.
#[derive(Debug, Clone)]
pub enum LlmStreamEvent {
    /// A piece of generated text; `complete` appends it to the response content.
    Delta(String),
    /// A list of tool calls; `complete` stores it as the response's tool calls.
    ToolCalls(Vec<ToolCall>),
    /// Token usage; `complete` stores it as the response's usage.
    Usage(Usage),
    /// Carries an optional finish reason; `complete` stores it as the response's
    /// finish reason and keeps reading the stream afterwards.
    Done(Option<String>),
}
/// Client interface for an LLM backend.
///
/// Implementors must be `Send + Sync` and only have to provide
/// [`LlmClient::stream`]; [`LlmClient::complete`] has a default implementation
/// built on top of it.
#[async_trait]
pub trait LlmClient: Send + Sync {
    /// Runs `req` and returns the whole response at once.
    ///
    /// The default implementation opens the event stream via
    /// [`LlmClient::stream`], reads it until it ends, and folds the events into
    /// one [`LlmResponse`]:
    ///
    /// * `Delta` text is concatenated in arrival order; `content` stays `None`
    ///   when the concatenation is empty.
    /// * `ToolCalls`, `Usage` and `Done` each overwrite the corresponding
    ///   field, so when an event kind occurs more than once the last one wins.
    /// * `Done` does not stop the loop; later events are still processed.
    ///
    /// NOTE(review): if a backend reports usage or tool calls across several
    /// events, only the last event survives here — confirm that matches the
    /// stream implementations.
    ///
    /// # Errors
    ///
    /// Returns the error from opening the stream, or the first error item the
    /// stream yields; anything accumulated up to that point is discarded.
    async fn complete(&self, req: LlmRequest) -> Result<LlmResponse> {
        use futures::StreamExt;

        let mut events = self.stream(req).await?;
        let mut text = String::new();
        let mut response = LlmResponse::default();

        while let Some(item) = events.next().await {
            match item? {
                LlmStreamEvent::Done(reason) => response.finish_reason = reason,
                LlmStreamEvent::Usage(usage) => response.usage = Some(usage),
                LlmStreamEvent::ToolCalls(calls) => response.tool_calls = calls,
                LlmStreamEvent::Delta(chunk) => text.push_str(&chunk),
            }
        }

        // An empty buffer is reported as "no content" rather than `Some("")`.
        response.content = if text.is_empty() { None } else { Some(text) };
        Ok(response)
    }

    /// Runs `req` and returns the response as a stream of [`LlmStreamEvent`]s.
    ///
    /// This is the one method every implementor must provide.
    ///
    /// # Errors
    ///
    /// The outer `Result` reports a failure to obtain the stream; each stream
    /// item is itself a `Result`, so failures can also surface while reading.
    async fn stream(
        &self,
        req: LlmRequest,
    ) -> Result<BoxStream<'static, Result<LlmStreamEvent>>>;
}