use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use crate::errors::NoosResult;
/// Identifies which backend family an [`AiProvider`] implementation talks to.
///
/// Serialized as a lowercase string; each variant's wire name is spelled out
/// explicitly so a future rename of a Rust variant cannot silently change it.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
pub enum AiProviderType {
    /// Wire name `"anthropic"`.
    #[serde(rename = "anthropic")]
    Anthropic,
    /// Wire name `"openai"` (no separator between "open" and "ai").
    #[serde(rename = "openai")]
    OpenAi,
    /// Wire name `"google"`.
    #[serde(rename = "google")]
    Google,
    /// Wire name `"local"`.
    #[serde(rename = "local")]
    Local,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProviderMessage {
pub role: MessageRole,
pub content: String,
}
/// Who authored a [`ProviderMessage`].
///
/// Serialized in lowercase: `"system"`, `"user"`, `"assistant"`.
/// Derives `Hash` (like `AiProviderType`) so roles can key maps and sets,
/// e.g. when counting or grouping messages per role.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum MessageRole {
    /// Wire name `"system"`.
    System,
    /// Wire name `"user"`.
    User,
    /// Wire name `"assistant"`.
    Assistant,
}
/// Everything an [`AiProvider`] needs to produce one completion.
///
/// This type is not serde-serializable; providers translate it into their own
/// wire format.
#[derive(Clone)]
#[derive(Debug)]
pub struct CompletionRequest {
    /// Provider-specific model identifier, e.g. `"claude-3-5-sonnet"`.
    pub model: String,
    /// Conversation turns, in the order they should be presented.
    pub messages: Vec<ProviderMessage>,
    /// Optional system instruction carried separately from `messages`.
    ///
    /// NOTE(review): `messages` may also contain `MessageRole::System` turns;
    /// how the two combine is left to each provider — confirm the intended
    /// precedence.
    pub system_prompt: Option<String>,
    /// Token limit forwarded to the provider (presumably caps generated
    /// output; exact semantics are provider-specific).
    pub max_tokens: u32,
    /// Sampling temperature. The valid range is provider-specific and is not
    /// checked here.
    pub temperature: f32,
    /// Whether streaming output was requested.
    ///
    /// NOTE(review): `AiProvider` also exposes separate `complete` and
    /// `stream` methods; confirm whether implementations consult this flag.
    pub stream: bool,
}
/// Result of a non-streaming completion.
///
/// Field order is significant for serialized output and is kept as-is.
#[derive(Clone)]
#[derive(Debug)]
#[derive(Serialize, Deserialize)]
pub struct CompletionResponse {
    /// Generated text.
    pub text: String,
    /// Token accounting for this request.
    pub usage: TokenUsage,
    /// Model identifier associated with this response.
    ///
    /// NOTE(review): presumably the model the provider actually used, which
    /// may differ from the requested alias — confirm per provider.
    pub model: String,
}
/// Token accounting reported for one request.
///
/// `Default` is all-zero. The type is `Copy` because it is just two integers
/// and is passed around by value (e.g. inside stream events).
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct TokenUsage {
    /// Tokens attributed to the request input.
    pub input_tokens: u32,
    /// Tokens attributed to the generated output.
    pub output_tokens: u32,
}

impl TokenUsage {
    /// Returns `input_tokens + output_tokens`.
    ///
    /// The sum is computed in `u64`, so two `u32` counts can never overflow
    /// (a `u32` addition could wrap in release builds or panic in debug).
    #[must_use]
    pub fn total_tokens(self) -> u64 {
        u64::from(self.input_tokens) + u64::from(self.output_tokens)
    }
}
/// One event delivered over the channel handed to [`AiProvider::stream`].
#[derive(Clone)]
#[derive(Debug)]
pub enum StreamChunk {
    /// An incremental piece of generated text.
    TextDelta(String),
    /// Token accounting for the streamed request.
    Usage(TokenUsage),
    /// End-of-stream marker.
    ///
    /// NOTE(review): presumably no chunks follow this one; nothing in this
    /// type enforces that.
    Done,
    /// A failure described as text.
    Error(String),
}
/// A chat-completion backend.
///
/// The `Send + Sync` supertraits let implementations be shared across async
/// tasks and threads (e.g. behind `Arc<dyn AiProvider>`).
#[async_trait]
pub trait AiProvider: Send + Sync {
    /// Runs `request` to completion and returns the whole response at once.
    ///
    /// # Errors
    /// Returns an error if the provider call fails (details are
    /// implementation-defined).
    async fn complete(&self, request: CompletionRequest) -> NoosResult<CompletionResponse>;

    /// Runs `request` and pushes incremental [`StreamChunk`]s into `sender`.
    ///
    /// NOTE(review): failures can surface either as a returned `Err` or as a
    /// `StreamChunk::Error` sent on the channel, and sending
    /// `StreamChunk::Done` is not enforced. Confirm the intended contract
    /// with implementors.
    ///
    /// # Errors
    /// Returns an error if the provider call fails (details are
    /// implementation-defined).
    async fn stream(
        &self,
        request: CompletionRequest,
        sender: tokio::sync::mpsc::Sender<StreamChunk>,
    ) -> NoosResult<()>;

    /// Which backend family this implementation belongs to.
    fn provider_type(&self) -> AiProviderType;
}
/// A backend that converts text into `f32` embedding vectors.
///
/// The `Send + Sync` supertraits allow sharing across async tasks and threads.
#[async_trait]
pub trait EmbeddingProvider: Send + Sync {
    /// Length of the vectors this provider produces.
    ///
    /// NOTE(review): the trait does not check that `embed` returns vectors of
    /// exactly this length; implementations are presumably expected to.
    fn dimension(&self) -> usize;

    /// Embeds `text` into a single vector.
    ///
    /// # Errors
    /// Returns an error if the provider call fails (details are
    /// implementation-defined).
    async fn embed(&self, text: &str) -> NoosResult<Vec<f32>>;
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Request shared by several tests so each one checks a different aspect
    /// of the same value.
    fn sample_request() -> CompletionRequest {
        CompletionRequest {
            model: "claude-3-5-sonnet".into(),
            messages: vec![ProviderMessage {
                role: MessageRole::User,
                content: "Hello".into(),
            }],
            system_prompt: Some("You are helpful".into()),
            max_tokens: 1024,
            temperature: 0.7,
            stream: false,
        }
    }

    #[test]
    fn completion_request_builds() {
        let req = sample_request();
        // Pin every field, not just the message count, so a mis-wired field
        // is actually caught.
        assert_eq!(req.model, "claude-3-5-sonnet");
        assert_eq!(req.messages.len(), 1);
        assert_eq!(req.messages[0].role, MessageRole::User);
        assert_eq!(req.messages[0].content, "Hello");
        assert_eq!(req.system_prompt.as_deref(), Some("You are helpful"));
        assert_eq!(req.max_tokens, 1024);
        assert!((req.temperature - 0.7).abs() < f32::EPSILON);
        assert!(!req.stream);
    }

    #[test]
    fn completion_request_clone_is_independent() {
        let original = sample_request();
        let mut copy = original.clone();
        copy.messages.push(ProviderMessage {
            role: MessageRole::Assistant,
            content: "Hi!".into(),
        });
        copy.stream = true;
        // Mutating the clone must not affect the original.
        assert_eq!(original.messages.len(), 1);
        assert!(!original.stream);
        assert_eq!(copy.messages.len(), 2);
        assert_eq!(copy.messages[1].role, MessageRole::Assistant);
    }

    #[test]
    fn token_usage_default() {
        let usage = TokenUsage::default();
        assert_eq!(usage.input_tokens, 0);
        assert_eq!(usage.output_tokens, 0);
    }

    #[test]
    fn provider_types_are_distinct_hash_keys() {
        let kinds: HashSet<AiProviderType> = [
            AiProviderType::Anthropic,
            AiProviderType::OpenAi,
            AiProviderType::Google,
            AiProviderType::Local,
            AiProviderType::Anthropic,
        ]
        .iter()
        .copied()
        .collect();
        assert_eq!(kinds.len(), 4);
    }

    #[test]
    fn stream_chunks_can_be_reassembled() {
        let chunks = vec![
            StreamChunk::TextDelta("Hel".into()),
            StreamChunk::TextDelta("lo".into()),
            StreamChunk::Usage(TokenUsage {
                input_tokens: 3,
                output_tokens: 2,
            }),
            StreamChunk::Done,
        ];
        let text: String = chunks
            .iter()
            .filter_map(|chunk| match chunk {
                StreamChunk::TextDelta(piece) => Some(piece.as_str()),
                _ => None,
            })
            .collect();
        assert_eq!(text, "Hello");
        assert!(matches!(chunks.last(), Some(StreamChunk::Done)));
    }
}