use serde::{Deserialize, Serialize};
use thiserror::Error;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmRequest {
pub prompt: String,
pub system: Option<String>,
pub max_tokens: u32,
pub temperature: f64,
pub stop_sequences: Vec<String>,
}
impl LlmRequest {
    /// Creates a request for `prompt` with defaults: no system prompt,
    /// `max_tokens = 1024`, `temperature = 0.7`, and no stop sequences.
    #[must_use]
    pub fn new(prompt: impl Into<String>) -> Self {
        let prompt: String = prompt.into();
        Self {
            stop_sequences: Vec::new(),
            temperature: 0.7,
            max_tokens: 1024,
            system: None,
            prompt,
        }
    }

    /// Sets (replacing any previous value) the system prompt.
    #[must_use]
    pub fn with_system(self, system: impl Into<String>) -> Self {
        Self {
            system: Some(system.into()),
            ..self
        }
    }

    /// Overrides the token limit.
    #[must_use]
    pub fn with_max_tokens(self, max_tokens: u32) -> Self {
        Self { max_tokens, ..self }
    }

    /// Overrides the sampling temperature (no range check is applied).
    #[must_use]
    pub fn with_temperature(self, temperature: f64) -> Self {
        Self { temperature, ..self }
    }

    /// Appends one stop sequence; earlier ones are kept.
    #[must_use]
    pub fn with_stop_sequence(self, stop: impl Into<String>) -> Self {
        let mut request = self;
        request.stop_sequences.push(stop.into());
        request
    }
}
/// Successful result of [`LlmProvider::complete`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmResponse {
/// Completion text returned by the provider.
pub content: String,
/// Model identifier attached to this response.
/// NOTE(review): may or may not equal `LlmProvider::model()` (e.g. if the
/// provider reports a resolved model name) — confirm per provider.
pub model: String,
/// Token counts for this completion.
pub usage: TokenUsage,
/// Why the provider stopped generating.
pub finish_reason: FinishReason,
}
/// Token counts for one completion.
///
/// The three counts are stored exactly as supplied; nothing on this type
/// computes or verifies `total_tokens == prompt_tokens + completion_tokens`.
/// `Default` yields all zeros.
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
pub struct TokenUsage {
/// Tokens attributed to the prompt.
pub prompt_tokens: u32,
/// Tokens attributed to the generated completion.
pub completion_tokens: u32,
/// Total tokens as reported (not derived from the other two fields here).
pub total_tokens: u32,
}
/// Reason a completion ended.
///
/// Serialized in snake_case: `"stop"`, `"max_tokens"`, `"stop_sequence"`,
/// `"content_filter"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum FinishReason {
Stop,
MaxTokens,
StopSequence,
ContentFilter,
}
/// Error type returned by [`LlmProvider`] and [`ModelSelectorTrait`] methods.
///
/// `Display` renders as `"{kind:?}: {message}"`, e.g. `RateLimit: slow down`.
#[derive(Debug, Clone, Serialize, Deserialize, Error)]
#[error("{kind:?}: {message}")]
pub struct LlmError {
/// Coarse error category.
pub kind: LlmErrorKind,
/// Human-readable detail.
pub message: String,
/// Whether retrying the same call may succeed. The kind-specific
/// constructors choose this per kind; `LlmError::new` takes it explicitly.
pub retryable: bool,
}
/// Coarse category of an [`LlmError`].
///
/// No serde `rename_all` is applied, so variants serialize under their Rust
/// names (e.g. `"RateLimit"`), unlike `FinishReason`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum LlmErrorKind {
Authentication,
RateLimit,
InvalidRequest,
ModelNotFound,
Network,
ProviderError,
ParseError,
Timeout,
}
impl LlmError {
    /// Creates an error with an explicit kind, message, and retry hint.
    ///
    /// Prefer the kind-specific constructors below, which pick `retryable`
    /// consistently: rate limits, network failures, and timeouts are
    /// retryable; every other kind is not.
    #[must_use]
    pub fn new(kind: LlmErrorKind, message: impl Into<String>, retryable: bool) -> Self {
        Self {
            kind,
            message: message.into(),
            retryable,
        }
    }

    /// [`LlmErrorKind::Authentication`] error; not retryable.
    #[must_use]
    pub fn auth(message: impl Into<String>) -> Self {
        Self::new(LlmErrorKind::Authentication, message, false)
    }

    /// [`LlmErrorKind::RateLimit`] error; retryable.
    #[must_use]
    pub fn rate_limit(message: impl Into<String>) -> Self {
        Self::new(LlmErrorKind::RateLimit, message, true)
    }

    /// [`LlmErrorKind::InvalidRequest`] error; not retryable, since resending
    /// the same request cannot fix it.
    #[must_use]
    pub fn invalid_request(message: impl Into<String>) -> Self {
        Self::new(LlmErrorKind::InvalidRequest, message, false)
    }

    /// [`LlmErrorKind::ModelNotFound`] error; not retryable.
    #[must_use]
    pub fn model_not_found(message: impl Into<String>) -> Self {
        Self::new(LlmErrorKind::ModelNotFound, message, false)
    }

    /// [`LlmErrorKind::Network`] error; retryable.
    #[must_use]
    pub fn network(message: impl Into<String>) -> Self {
        Self::new(LlmErrorKind::Network, message, true)
    }

    /// [`LlmErrorKind::ParseError`] error; not retryable.
    #[must_use]
    pub fn parse(message: impl Into<String>) -> Self {
        Self::new(LlmErrorKind::ParseError, message, false)
    }

    /// [`LlmErrorKind::ProviderError`] error; not retryable.
    #[must_use]
    pub fn provider(message: impl Into<String>) -> Self {
        Self::new(LlmErrorKind::ProviderError, message, false)
    }

    /// [`LlmErrorKind::Timeout`] error; retryable.
    #[must_use]
    pub fn timeout(message: impl Into<String>) -> Self {
        Self::new(LlmErrorKind::Timeout, message, true)
    }
}
/// A synchronous LLM backend that can be shared across threads
/// (`Send + Sync`).
pub trait LlmProvider: Send + Sync {
    /// Static identifier of this provider implementation.
    fn name(&self) -> &'static str;

    /// Identifier of the model this provider uses.
    fn model(&self) -> &str;

    /// Performs one completion.
    ///
    /// # Errors
    /// Returns an [`LlmError`] when the completion fails.
    fn complete(&self, request: &LlmRequest) -> Result<LlmResponse, LlmError>;

    /// Builds a provenance tag of the form `"<model>:<request_id>"`.
    fn provenance(&self, request_id: &str) -> String {
        [self.model(), request_id].join(":")
    }

    /// Liveness probe: sends the prompt `"Say OK"` with a one-token limit
    /// and discards the response.
    ///
    /// # Errors
    /// Propagates any error from [`LlmProvider::complete`] unchanged.
    fn health_check(&self) -> Result<(), LlmError> {
        let probe = LlmRequest::new("Say OK").with_max_tokens(1);
        self.complete(&probe)?;
        Ok(())
    }
}
/// Requirements an agent passes to [`ModelSelectorTrait::select`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentRequirements {
/// Highest acceptable cost class.
pub max_cost_class: CostClass,
/// Latency budget in milliseconds.
pub max_latency_ms: u32,
/// Whether the selected model must support reasoning.
pub requires_reasoning: bool,
/// Whether web search support is required.
pub requires_web_search: bool,
/// Minimum quality threshold.
/// NOTE(review): the scale is not defined here; `new` uses 0.0 and the presets
/// use 0.8/0.9, suggesting 0.0–1.0 — confirm.
pub min_quality: f64,
/// Data-residency constraint (`DataSovereignty::Any` via `new`).
pub data_sovereignty: DataSovereignty,
/// Compliance constraint (`ComplianceLevel::None` via `new`).
pub compliance: ComplianceLevel,
/// Whether multilingual support is required.
pub requires_multilingual: bool,
}
impl AgentRequirements {
    /// Builds requirements from a cost ceiling, latency budget, and reasoning
    /// flag. All other fields start at: no web search, `min_quality = 0.0`,
    /// `DataSovereignty::Any`, `ComplianceLevel::None`, no multilingual
    /// requirement.
    #[must_use]
    pub fn new(max_cost_class: CostClass, max_latency_ms: u32, requires_reasoning: bool) -> Self {
        Self {
            requires_multilingual: false,
            compliance: ComplianceLevel::None,
            data_sovereignty: DataSovereignty::Any,
            min_quality: 0.0,
            requires_web_search: false,
            requires_reasoning,
            max_latency_ms,
            max_cost_class,
        }
    }

    /// Preset: `CostClass::VeryLow`, 2000 ms budget, no reasoning,
    /// no quality floor.
    #[must_use]
    pub fn fast_cheap() -> Self {
        Self::new(CostClass::VeryLow, 2_000, false)
    }

    /// Preset: `CostClass::Low`, 5000 ms budget, no reasoning,
    /// `min_quality = 0.8`.
    #[must_use]
    pub fn balanced() -> Self {
        let base = Self::new(CostClass::Low, 5_000, false);
        base.with_quality(0.8)
    }

    /// Preset: `CostClass::High`, 10000 ms budget, reasoning required,
    /// `min_quality = 0.9`.
    #[must_use]
    pub fn powerful() -> Self {
        let base = Self::new(CostClass::High, 10_000, true);
        base.with_quality(0.9)
    }

    /// Sets the minimum quality threshold (stored as given, no clamping).
    #[must_use]
    pub fn with_quality(self, quality: f64) -> Self {
        Self {
            min_quality: quality,
            ..self
        }
    }

    /// Sets whether web search is required.
    #[must_use]
    pub fn with_web_search(self, required: bool) -> Self {
        Self {
            requires_web_search: required,
            ..self
        }
    }

    /// Sets the data-sovereignty constraint.
    #[must_use]
    pub fn with_data_sovereignty(self, sovereignty: DataSovereignty) -> Self {
        Self {
            data_sovereignty: sovereignty,
            ..self
        }
    }

    /// Sets the compliance constraint.
    #[must_use]
    pub fn with_compliance(self, compliance: ComplianceLevel) -> Self {
        Self { compliance, ..self }
    }

    /// Sets whether multilingual support is required.
    #[must_use]
    pub fn with_multilingual(self, required: bool) -> Self {
        Self {
            requires_multilingual: required,
            ..self
        }
    }
}
/// Chooses a model for a given set of [`AgentRequirements`]; implementors must
/// be `Send + Sync`.
pub trait ModelSelectorTrait: Send + Sync {
/// Selects a model satisfying `requirements`.
///
/// NOTE(review): the meaning of the returned `(String, String)` pair is not
/// defined here — presumably `(provider, model)`; confirm against implementors.
///
/// # Errors
/// Returns an [`LlmError`] on failure; the exact conditions are
/// implementation-defined.
fn select(&self, requirements: &AgentRequirements) -> Result<(String, String), LlmError>;
}
pub use converge_traits::{ComplianceLevel, CostClass, DataSovereignty};