use std::collections::HashMap;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use crate::error::Result;
/// A single chat message exchanged with a language model.
///
/// The role is stored as a free-form string; the constructors on this type
/// produce `"system"`, `"user"` and `"assistant"`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Message {
    /// Who authored the message (e.g. `"system"`, `"user"`, `"assistant"`).
    pub role: String,
    /// The textual body of the message.
    pub content: String,
    /// Optional author name; left out of the serialized form when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
}
impl Message {
    /// Builds an unnamed message carrying the given role tag and content.
    fn from_role(role: &str, content: impl Into<String>) -> Self {
        Self {
            role: role.to_owned(),
            content: content.into(),
            name: None,
        }
    }

    /// Creates a message with role `"system"` and no name.
    pub fn system(content: impl Into<String>) -> Self {
        Self::from_role("system", content)
    }

    /// Creates a message with role `"user"` and no name.
    pub fn user(content: impl Into<String>) -> Self {
        Self::from_role("user", content)
    }

    /// Creates a message with role `"assistant"` and no name.
    pub fn assistant(content: impl Into<String>) -> Self {
        Self::from_role("assistant", content)
    }

    /// Consumes the message and returns it with `name` set to the given value,
    /// replacing any name it previously carried.
    pub fn with_name(self, name: impl Into<String>) -> Self {
        Self {
            name: Some(name.into()),
            ..self
        }
    }
}
/// Requested output format for a model response.
///
/// Serialized as an object with a single `type` key holding `format_type`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResponseFormat {
    /// Format identifier, written to / read from the `type` key.
    #[serde(rename = "type")]
    pub format_type: String,
}

/// The default format is plain text.
///
/// A derived `Default` would leave `format_type` as the empty string, which
/// serializes to `{"type":""}` and corresponds to none of the formats this
/// type can construct. The value here matches what `ResponseFormat::text()`
/// produces.
impl Default for ResponseFormat {
    fn default() -> Self {
        Self {
            format_type: "text".to_string(),
        }
    }
}
impl ResponseFormat {
    /// Returns a format whose `format_type` is `"text"`.
    pub fn text() -> Self {
        let format_type = String::from("text");
        Self { format_type }
    }

    /// Returns a format whose `format_type` is `"json_object"`.
    pub fn json_object() -> Self {
        let format_type = String::from("json_object");
        Self { format_type }
    }
}
/// A tool definition that can be passed along with a generation request.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Tool {
    /// Kind of tool, written to / read from the `type` key.
    #[serde(rename = "type")]
    pub tool_type: String,
    /// Definition of the function this tool exposes.
    pub function: FunctionDef,
}
/// Describes a function that a [`Tool`] exposes.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct FunctionDef {
    /// Name of the function.
    pub name: String,
    /// Human-readable description of the function.
    pub description: String,
    /// Parameter specification as an arbitrary JSON value.
    // NOTE(review): presumably a JSON Schema object — confirm against callers.
    pub parameters: serde_json::Value,
}
/// A single tool invocation carried in an [`LlmResponse`].
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolCall {
    /// Name of the function being called.
    pub name: String,
    /// Call arguments, keyed by argument name, each an arbitrary JSON value.
    pub arguments: HashMap<String, serde_json::Value>,
}
/// Result of a generation request.
#[derive(Clone, Debug)]
pub enum LlmResponse {
    /// A plain text reply.
    Text(String),
    /// A reply that carries tool calls, optionally accompanied by text.
    WithToolCalls {
        /// Text that came with the tool calls, if any.
        content: Option<String>,
        /// The tool invocations included in the reply.
        tool_calls: Vec<ToolCall>,
    },
}
impl LlmResponse {
pub fn text(&self) -> Option<&str> {
match self {
LlmResponse::Text(s) => Some(s),
LlmResponse::WithToolCalls { content, .. } => content.as_deref(),
}
}
pub fn text_or_empty(&self) -> &str {
self.text().unwrap_or("")
}
}
/// Common interface for language-model backends.
///
/// Implementors supply [`LlmBase::generate_response`]; the remaining methods
/// are convenience wrappers with default implementations built on top of it.
#[async_trait]
pub trait LlmBase: Send + Sync {
    /// Sends `messages` to the model and returns its response.
    ///
    /// `response_format`, `tools` and `tool_choice` are optional and handed to
    /// the implementation as-is; how each is interpreted is up to the
    /// implementor.
    ///
    /// # Errors
    ///
    /// Returns whatever error the implementation reports.
    async fn generate_response(
        &self,
        messages: Vec<Message>,
        response_format: Option<ResponseFormat>,
        tools: Option<Vec<Tool>>,
        tool_choice: Option<String>,
    ) -> Result<LlmResponse>;

    /// Generates a reply with no response format, tools or tool choice and
    /// returns its text, or an empty string when the response has no text.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`LlmBase::generate_response`].
    async fn generate(&self, messages: Vec<Message>) -> Result<String> {
        self.generate_response(messages, None, None, None)
            .await
            .map(|reply| reply.text_or_empty().to_owned())
    }

    /// Generates a reply with the `json_object` response format requested and
    /// returns its text, or an empty string when the response has no text.
    ///
    /// The text is returned as received; it is not parsed or validated here.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`LlmBase::generate_response`].
    async fn generate_json(&self, messages: Vec<Message>) -> Result<String> {
        let json_format = Some(ResponseFormat::json_object());
        self.generate_response(messages, json_format, None, None)
            .await
            .map(|reply| reply.text_or_empty().to_owned())
    }
}