#![warn(missing_docs)]
use async_trait::async_trait;
use futures::stream::BoxStream;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
pub mod error;
pub mod middleware;
pub mod mock;
pub mod ratelimit;
pub mod types;
pub use error::{Error, Result};
pub use middleware::{Layer, Middleware, ProviderExt};
pub use mock::MockProvider;
pub use ratelimit::TokenBucket;
/// Abstraction over a backend that turns a [`Request`] into a [`Response`],
/// either as a whole or as a stream of [`Chunk`]s.
///
/// Implementors must be `Send + Sync`.
#[async_trait]
pub trait Provider: Send + Sync {
/// Sends `req` and resolves to a single [`Response`], or to an error.
async fn complete(&self, req: Request) -> Result<Response>;
/// Sends `req` and resolves to a boxed `'static` stream whose items are `Result<Chunk>`.
///
/// Failure can be reported at two levels: by the outer `Result` returned from this
/// call, and by any individual item yielded from the stream.
async fn stream(&self, req: Request) -> Result<BoxStream<'static, Result<Chunk>>>;
}
/// Abstraction over a backend that maps input strings to `f32` vectors.
///
/// Implementors must be `Send + Sync`.
#[async_trait]
pub trait EmbeddingProvider: Send + Sync {
/// Resolves to a list of `f32` vectors computed from `inputs`, or to an error.
///
/// NOTE(review): presumably one vector per input, in input order — confirm against
/// the implementations; the signature alone does not guarantee it.
async fn embed(&self, inputs: Vec<String>) -> Result<Vec<Vec<f32>>>;
}
/// A single message exchanged with a provider.
///
/// `name`, `tool_calls` and `tool_call_id` are left out of the serialized form when
/// they are `None`.
///
/// NOTE(review): `content` is a plain `String`, so deserializing a payload whose
/// `content` is `null` or missing fails — confirm no supported provider sends that
/// (for example alongside `tool_calls`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct Message {
/// Role associated with this message.
pub role: Role,
/// Text of the message.
pub content: String,
/// Optional name attached to the message; omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub name: Option<String>,
/// Optional list of tool calls carried by the message; omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<ToolCall>>,
/// Optional tool-call identifier (set by [`Message::tool_result`]); omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_call_id: Option<String>,
}
/// Role attached to a [`Message`] or to a streamed [`Delta`].
///
/// Every variant serializes as its lowercase name (`"system"`, `"user"`,
/// `"assistant"`, `"function"`, `"tool"`) via `rename_all = "lowercase"`, so no
/// per-variant rename is needed.
///
/// The enum is fieldless, so it is `Copy` and can be compared and hashed cheaply.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    /// Role set by [`Message::system`].
    System,
    /// Role set by [`Message::user`].
    User,
    /// Role set by [`Message::assistant`].
    Assistant,
    /// The `function` role.
    ///
    /// NOTE(review): presumably kept for an older function-calling message format —
    /// confirm; nothing in this file constructs it.
    Function,
    /// Role set by [`Message::tool_result`].
    Tool,
}
impl Message {
pub fn system(content: impl Into<String>) -> Self {
Self {
role: Role::System,
content: content.into(),
name: None,
tool_calls: None,
tool_call_id: None,
}
}
pub fn user(content: impl Into<String>) -> Self {
Self {
role: Role::User,
content: content.into(),
name: None,
tool_calls: None,
tool_call_id: None,
}
}
pub fn assistant(content: impl Into<String>) -> Self {
Self {
role: Role::Assistant,
content: content.into(),
name: None,
tool_calls: None,
tool_call_id: None,
}
}
pub fn tool_result(content: impl Into<String>, tool_call_id: impl Into<String>) -> Self {
Self {
role: Role::Tool,
content: content.into(),
name: None,
tool_calls: None,
tool_call_id: Some(tool_call_id.into()),
}
}
}
/// A tool invocation attached to a [`Message`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ToolCall {
/// Identifier of this call.
pub id: String,
/// Kind of call; serialized and deserialized under the key `type`.
#[serde(rename = "type")]
pub call_type: String,
/// Name and arguments of the function being called.
pub function: ToolCallFunction,
}
/// Function name and argument payload of a [`ToolCall`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ToolCallFunction {
/// Name of the function to call.
pub name: String,
/// Arguments as a string.
///
/// NOTE(review): presumably a JSON-encoded object — confirm; this type stores it
/// as an opaque `String` and does not parse it.
pub arguments: String,
}
/// Parameters of a single provider call.
///
/// Every `Option` field is left out of the serialized form when it is `None`.
/// `model` and `messages` carry no serde default, so both must be present when a
/// `Request` is deserialized.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct Request {
/// Model identifier; empty in a default-constructed request.
pub model: String,
/// Messages to send, in order.
pub messages: Vec<Message>,
/// Optional `temperature` value.
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f32>,
/// Optional `max_tokens` value.
#[serde(skip_serializing_if = "Option::is_none")]
pub max_tokens: Option<u32>,
/// Optional `top_p` value.
#[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f32>,
/// Optional `frequency_penalty` value.
#[serde(skip_serializing_if = "Option::is_none")]
pub frequency_penalty: Option<f32>,
/// Optional `presence_penalty` value.
#[serde(skip_serializing_if = "Option::is_none")]
pub presence_penalty: Option<f32>,
/// Optional list of stop strings.
#[serde(skip_serializing_if = "Option::is_none")]
pub stop: Option<Vec<String>>,
/// Optional `stream` flag.
///
/// NOTE(review): how this relates to calling [`Provider::stream`] versus
/// [`Provider::complete`] is not visible in this file — confirm in the providers.
#[serde(skip_serializing_if = "Option::is_none")]
pub stream: Option<bool>,
/// Optional response format; see [`Request::with_json_mode`].
#[serde(skip_serializing_if = "Option::is_none")]
pub response_format: Option<ResponseFormat>,
/// Additional key/value pairs.
///
/// Omitted from serialized output when empty and defaulted to an empty map when
/// absent on deserialization. The field is not flattened, so the pairs are
/// serialized nested under the key `extra`.
#[serde(skip_serializing_if = "HashMap::is_empty", default)]
pub extra: HashMap<String, serde_json::Value>,
}
impl Request {
    /// Creates an empty request; equivalent to `Request::default()`.
    pub fn new() -> Self {
        Default::default()
    }

    /// Replaces the model identifier.
    pub fn with_model(self, model: impl Into<String>) -> Self {
        Self {
            model: model.into(),
            ..self
        }
    }

    /// Replaces the whole message list with `messages`.
    pub fn with_messages(self, messages: Vec<Message>) -> Self {
        Self { messages, ..self }
    }

    /// Appends one message to the end of the existing message list.
    pub fn with_message(self, message: Message) -> Self {
        let mut request = self;
        request.messages.push(message);
        request
    }

    /// Sets `temperature` to `Some(temperature)`.
    pub fn with_temperature(self, temperature: f32) -> Self {
        Self {
            temperature: Some(temperature),
            ..self
        }
    }

    /// Sets `max_tokens` to `Some(max_tokens)`.
    pub fn with_max_tokens(self, max_tokens: u32) -> Self {
        Self {
            max_tokens: Some(max_tokens),
            ..self
        }
    }

    /// Sets `top_p` to `Some(top_p)`.
    pub fn with_top_p(self, top_p: f32) -> Self {
        Self {
            top_p: Some(top_p),
            ..self
        }
    }

    /// Sets `response_format` to [`ResponseFormat::json_object`].
    pub fn with_json_mode(self) -> Self {
        Self {
            response_format: Some(ResponseFormat::json_object()),
            ..self
        }
    }

    /// Inserts `key`/`value` into `extra`, overwriting any previous value stored
    /// under the same key.
    pub fn with_extra(
        self,
        key: impl Into<String>,
        value: impl Into<serde_json::Value>,
    ) -> Self {
        let mut request = self;
        request.extra.insert(key.into(), value.into());
        request
    }
}
/// Requested output format of a [`Request`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResponseFormat {
/// Format identifier; serialized and deserialized under the key `type`.
#[serde(rename = "type")]
pub format_type: String,
}
impl ResponseFormat {
    /// Builds a format whose `format_type` is the string `"json_object"`.
    pub fn json_object() -> Self {
        let format_type = String::from("json_object");
        Self { format_type }
    }
}
/// Result of a non-streaming provider call.
///
/// `usage` and `created` have no `skip_serializing_if`, so they serialize as `null`
/// when they are `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Response {
/// Identifier of the response.
pub id: String,
/// Model identifier reported with the response.
pub model: String,
/// Returned choices; the accessor methods on this type read only the first one.
pub choices: Vec<Choice>,
/// Token accounting, when available.
pub usage: Option<Usage>,
/// Creation time, when available.
///
/// NOTE(review): presumably a Unix timestamp in seconds — confirm against the providers.
pub created: Option<u64>,
}
impl Response {
    /// Returns the text of the first choice's message, or `""` when there are no
    /// choices.
    pub fn content(&self) -> &str {
        self.choices
            .first()
            .map_or("", |choice| choice.message.content.as_str())
    }

    /// Borrows the token accounting, if the response carries any.
    pub fn usage(&self) -> Option<&Usage> {
        self.usage.as_ref()
    }

    /// Borrows the tool calls of the first choice's message.
    ///
    /// Returns `None` when there are no choices or when that message has no
    /// `tool_calls`.
    pub fn tool_calls(&self) -> Option<&Vec<ToolCall>> {
        let first = self.choices.first()?;
        first.message.tool_calls.as_ref()
    }
}
/// One entry of [`Response::choices`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Choice {
/// Index of this choice.
pub index: u32,
/// Message carried by this choice.
pub message: Message,
/// Finish reason, when present; omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub finish_reason: Option<String>,
}
/// Token counts attached to a [`Response`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Usage {
/// Number of prompt tokens.
pub prompt_tokens: u32,
/// Number of completion tokens.
pub completion_tokens: u32,
/// Total number of tokens.
///
/// NOTE(review): stored as reported; nothing in this file checks that it equals
/// `prompt_tokens + completion_tokens`.
pub total_tokens: u32,
}
impl Usage {
    /// Computes the cost of this usage from per-1000-token prices.
    ///
    /// `prompt_price` applies to `prompt_tokens` and `completion_price` to
    /// `completion_tokens`; each count is divided by 1000 before being multiplied
    /// by its price. `total_tokens` is not used. The result is in whatever unit
    /// the prices are expressed in.
    pub fn calculate_cost(&self, prompt_price: f64, completion_price: f64) -> f64 {
        let prompt_thousands = f64::from(self.prompt_tokens) / 1000.0;
        let completion_thousands = f64::from(self.completion_tokens) / 1000.0;
        prompt_thousands * prompt_price + completion_thousands * completion_price
    }
}
/// One item of the stream returned by [`Provider::stream`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Chunk {
/// Identifier carried by the chunk.
pub id: String,
/// Model identifier carried by the chunk.
pub model: String,
/// Incremental payload of this chunk.
pub delta: Delta,
/// Finish reason, when present; [`Chunk::is_finished`] reports whether it is set.
pub finish_reason: Option<String>,
}
impl Chunk {
    /// Returns the text fragment stored in this chunk's delta.
    pub fn content(&self) -> &str {
        self.delta.content.as_str()
    }

    /// Returns `true` when this chunk carries a finish reason.
    pub fn is_finished(&self) -> bool {
        let Self { finish_reason, .. } = self;
        finish_reason.is_some()
    }
}
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct Delta {
pub role: Option<Role>,
pub content: String,
}
/// Connection settings for a provider.
///
/// `Debug` is implemented by hand so that `api_key` is never printed: formatting a
/// config with `{:?}` (for example in logs or panic messages) shows `<redacted>`
/// in place of the secret.
#[derive(Clone)]
pub struct ProviderConfig {
    /// Secret API key.
    pub api_key: String,
    /// Base URL; empty unless set through [`ProviderConfig::with_base_url`].
    ///
    /// NOTE(review): presumably each provider substitutes its own endpoint when
    /// this is empty — confirm in the provider implementations.
    pub base_url: String,
    /// Timeout in seconds; defaults to 60.
    pub timeout_seconds: u64,
    /// Maximum number of retries; defaults to 3.
    pub max_retries: u32,
}

impl std::fmt::Debug for ProviderConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ProviderConfig")
            // Never emit the key itself; a fixed placeholder keeps the field visible.
            .field("api_key", &"<redacted>")
            .field("base_url", &self.base_url)
            .field("timeout_seconds", &self.timeout_seconds)
            .field("max_retries", &self.max_retries)
            .finish()
    }
}

impl ProviderConfig {
    /// Creates a config with the given API key, an empty base URL, a 60-second
    /// timeout and 3 retries.
    pub fn new(api_key: impl Into<String>) -> Self {
        Self {
            api_key: api_key.into(),
            base_url: String::new(),
            timeout_seconds: 60,
            max_retries: 3,
        }
    }

    /// Replaces the base URL.
    pub fn with_base_url(mut self, url: impl Into<String>) -> Self {
        self.base_url = url.into();
        self
    }

    /// Replaces the timeout, in seconds.
    pub fn with_timeout(mut self, seconds: u64) -> Self {
        self.timeout_seconds = seconds;
        self
    }

    /// Replaces the maximum number of retries.
    pub fn with_max_retries(mut self, retries: u32) -> Self {
        self.max_retries = retries;
        self
    }
}