#[cfg(feature = "openai")]
mod openai;
#[cfg(feature = "openai")]
#[allow(unused_imports)]
pub use openai::{OpenAIProvider, OpenAiApiMode};
use crate::error::AgentResult;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
/// Common interface implemented by every chat-completion backend.
///
/// The `Send + Sync` supertraits allow a provider value to be shared between
/// threads.
#[async_trait]
pub trait LLMProvider: Send + Sync {
    /// Requests a completion for the given conversation.
    ///
    /// # Arguments
    ///
    /// * `messages` - the conversation to send, taken by value.
    /// * `options` - per-request settings; see [`CompletionOptions`].
    /// * `client` - HTTP client supplied by the caller.
    ///
    /// # Returns
    ///
    /// An `AgentResult` wrapping the [`CompletionResponse`] (generated text
    /// plus token usage).
    async fn complete(
        &self,
        messages: Vec<Message>,
        options: &CompletionOptions,
        client: &reqwest::Client,
    ) -> AgentResult<CompletionResponse>;

    /// Returns a static name identifying this provider.
    fn provider_name(&self) -> &'static str;

    /// Reports whether this provider considers itself configured.
    // NOTE(review): the exact criteria (API key present, endpoint set, ...)
    // are decided by each implementor and are not visible here - confirm.
    fn is_configured(&self) -> bool;
}
/// One entry in a chat conversation: who is speaking and what they said.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    /// Speaker role as a free-form string. The constructors on this type
    /// produce `"system"`, `"user"` and `"assistant"`.
    pub role: String,
    /// Message body: either plain text or a list of text/image parts.
    pub content: MessageContent,
}
impl Message {
    /// Builds a plain-text message with the `"system"` role.
    pub fn system(content: impl Into<String>) -> Self {
        Self::text_with_role("system", content)
    }

    /// Builds a plain-text message with the `"user"` role.
    pub fn user(content: impl Into<String>) -> Self {
        Self::text_with_role("user", content)
    }

    /// Builds a plain-text message with the `"assistant"` role.
    pub fn assistant(content: impl Into<String>) -> Self {
        Self::text_with_role("assistant", content)
    }

    /// Builds a `"user"` message made of a text part followed by an image part.
    ///
    /// # Arguments
    ///
    /// * `text` - the text part of the message.
    /// * `image_base64` - either raw base64 image data, which is wrapped as
    ///   `data:image/png;base64,<data>`, or an already complete `data:` URL,
    ///   which is used verbatim so callers can supply a non-PNG media type.
    pub fn user_with_image(text: impl Into<String>, image_base64: impl Into<String>) -> Self {
        let image = image_base64.into();
        // ':' is not part of any base64 alphabet, so a leading "data:" can only
        // mean the caller already built the URL; prefixing it again would
        // produce an invalid, doubly-wrapped URL.
        let url = if image.starts_with("data:") {
            image
        } else {
            format!("data:image/png;base64,{}", image)
        };

        Self {
            role: "user".to_string(),
            content: MessageContent::MultiPart(vec![
                ContentPart::Text { text: text.into() },
                ContentPart::ImageUrl {
                    image_url: ImageUrl { url },
                },
            ]),
        }
    }

    /// Shared constructor behind the plain-text helpers.
    fn text_with_role(role: &str, content: impl Into<String>) -> Self {
        Self {
            role: role.to_string(),
            content: MessageContent::Text(content.into()),
        }
    }
}
/// Body of a [`Message`].
///
/// The enum is `#[serde(untagged)]`, so no variant name is written: a `Text`
/// value is serialized as a bare string and a `MultiPart` value as a sequence
/// of parts. On deserialization serde tries the variants in declaration order.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum MessageContent {
    /// A single plain-text body.
    Text(String),
    /// An ordered list of text and image parts.
    MultiPart(Vec<ContentPart>),
}
impl MessageContent {
    /// Returns the text of this content without allocating.
    ///
    /// For `Text` this is the whole string. For `MultiPart` it is the first
    /// text part only, or `""` when there is no text part at all.
    pub fn as_text(&self) -> &str {
        match self {
            Self::Text(body) => body,
            Self::MultiPart(parts) => parts
                .iter()
                .find_map(|part| match part {
                    ContentPart::Text { text } => Some(text.as_str()),
                    ContentPart::ImageUrl { .. } => None,
                })
                .unwrap_or(""),
        }
    }

    /// Returns all text of this content as an owned `String`.
    ///
    /// For `Text` this is a clone of the string. For `MultiPart` every text
    /// part is kept in order and joined with a single space; image parts are
    /// skipped.
    pub fn full_text(&self) -> String {
        let parts = match self {
            Self::Text(body) => return body.clone(),
            Self::MultiPart(parts) => parts,
        };

        let mut texts: Vec<&str> = Vec::new();
        for part in parts {
            if let ContentPart::Text { text } = part {
                texts.push(text.as_str());
            }
        }
        texts.join(" ")
    }

    /// Returns `true` when this content is the plain `Text` variant.
    pub fn is_text(&self) -> bool {
        match self {
            Self::Text(_) => true,
            Self::MultiPart(_) => false,
        }
    }

    /// Returns `true` when this is multi-part content with at least one image part.
    pub fn has_images(&self) -> bool {
        matches!(
            self,
            Self::MultiPart(parts)
                if parts.iter().any(|part| matches!(part, ContentPart::ImageUrl { .. }))
        )
    }
}
/// One element of a multi-part message body.
///
/// Serialized as an internally tagged object: the variant name, converted to
/// snake_case, is stored under the `"type"` key (`"text"` or `"image_url"`)
/// next to the variant's own fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentPart {
    /// A run of text.
    Text {
        /// The text itself.
        text: String,
    },
    /// A reference to an image.
    ImageUrl {
        /// Where the image can be found.
        image_url: ImageUrl,
    },
}
/// Wrapper object holding the location of an image.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImageUrl {
    /// The image URL. `Message::user_with_image` stores a base64 `data:` URL
    /// here.
    pub url: String,
}
/// Per-request settings passed to `LLMProvider::complete`.
#[derive(Debug, Clone)]
pub struct CompletionOptions {
    /// Sampling temperature forwarded to the provider.
    pub temperature: f32,
    /// Upper bound on generated tokens. Stored as `u16`, so it cannot exceed
    /// 65,535.
    pub max_tokens: u16,
    /// Flag requesting JSON output.
    // NOTE(review): how providers act on this flag is not visible here -
    // confirm against the provider implementations.
    pub json_mode: bool,
}
impl Default for CompletionOptions {
    /// Returns the default options: temperature `0.1`, at most `4096` tokens,
    /// and JSON mode switched on.
    fn default() -> Self {
        let temperature = 0.1;
        let max_tokens = 4096;
        let json_mode = true;

        Self {
            temperature,
            max_tokens,
            json_mode,
        }
    }
}
/// Result of a completion request as returned by `LLMProvider::complete`.
#[derive(Debug, Clone)]
pub struct CompletionResponse {
    /// The text carried by the response.
    pub content: String,
    /// Token counters attached to this response.
    pub usage: TokenUsage,
}
/// Token counters for one completion, or a running total over several.
///
/// `Default` yields all-zero counters.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct TokenUsage {
    /// Tokens counted for the prompt.
    pub prompt_tokens: u32,
    /// Tokens counted for the completion.
    pub completion_tokens: u32,
    /// Total token count. Stored as its own field; nothing in this type
    /// derives it from the other two counters.
    pub total_tokens: u32,
}
impl TokenUsage {
    /// Adds the counters of `other` onto `self`, field by field.
    ///
    /// Each sum saturates at `u32::MAX`. Plain `+=` would panic on overflow
    /// in debug builds and silently wrap to a small, wrong count in release
    /// builds, so a long-running accumulator is pinned at the maximum instead.
    pub fn accumulate(&mut self, other: &TokenUsage) {
        self.prompt_tokens = self.prompt_tokens.saturating_add(other.prompt_tokens);
        self.completion_tokens = self
            .completion_tokens
            .saturating_add(other.completion_tokens);
        self.total_tokens = self.total_tokens.saturating_add(other.total_tokens);
    }
}