use crate::api::{ModelAliasSpec, ModelTask};
use crate::error::Result;
use async_trait::async_trait;
use std::any::Any;
/// Capabilities advertised by a model provider.
///
/// This is the value returned by `ModelProvider::capabilities`.
#[derive(Debug, Clone)]
pub struct ProviderCapabilities {
/// The `ModelTask`s the provider declares support for.
pub supported_tasks: Vec<ModelTask>,
}
/// Health status reported by a provider via `ModelProvider::health`.
#[derive(Debug, Clone)]
pub enum ProviderHealth {
/// The provider reports no problems.
Healthy,
/// The provider reports a degraded state.
///
/// NOTE(review): the `String` payload is presumably a human-readable
/// explanation — confirm against implementors.
Degraded(String),
/// The provider reports that it is unhealthy.
///
/// NOTE(review): the `String` payload is presumably a human-readable
/// explanation — confirm against implementors.
Unhealthy(String),
}
/// Interface implemented by every model provider.
///
/// Implementors must be `Send + Sync`. The `async fn` methods are declared
/// through the `#[async_trait]` attribute.
#[async_trait]
pub trait ModelProvider: Send + Sync {
/// Returns a `'static` string identifying this provider.
fn provider_id(&self) -> &'static str;
/// Returns the capabilities this provider advertises.
fn capabilities(&self) -> ProviderCapabilities;
/// Loads the model described by `spec` and returns a type-erased handle.
///
/// # Errors
///
/// Returns the crate's error type on failure; the exact failure conditions
/// are defined by each implementor.
async fn load(&self, spec: &ModelAliasSpec) -> Result<LoadedModelHandle>;
/// Reports the provider's current health.
async fn health(&self) -> ProviderHealth;
/// Optional warm-up hook.
///
/// The default implementation does nothing and returns `Ok(())`;
/// implementors may override it.
async fn warmup(&self) -> Result<()> {
Ok(())
}
}
/// Type-erased, reference-counted handle returned by `ModelProvider::load`.
///
/// NOTE(review): consumers presumably recover the concrete model type by
/// downcasting the `dyn Any` — confirm against the call sites.
pub type LoadedModelHandle = std::sync::Arc<dyn Any + Send + Sync>;
/// Interface for models that turn text into embedding vectors.
///
/// Implementors must be `Send + Sync`; the `Any` supertrait additionally
/// requires them to be `'static`.
#[async_trait]
pub trait EmbeddingModel: Send + Sync + Any {
/// Computes embeddings for `texts`.
///
/// NOTE(review): presumably yields one `Vec<f32>` per input text, in input
/// order — confirm against implementors.
/// NOTE(review): takes an owned `Vec<&str>` whereas `RerankerModel::rerank`
/// takes `&[&str]`; changing it would break implementors.
///
/// # Errors
///
/// Returns the crate's error type on failure; conditions are
/// implementor-defined.
async fn embed(&self, texts: Vec<&str>) -> Result<Vec<Vec<f32>>>;
/// Returns the model's dimension count.
///
/// NOTE(review): presumably the length of each vector produced by
/// `embed` — confirm.
fn dimensions(&self) -> u32;
/// Returns an identifier string for this model.
fn model_id(&self) -> &str;
/// Optional warm-up hook.
///
/// The default implementation does nothing and returns `Ok(())`.
async fn warmup(&self) -> Result<()> {
Ok(())
}
}
/// A single scored document, as returned by `RerankerModel::rerank`.
#[derive(Debug, Clone)]
pub struct ScoredDoc {
/// Index of the document.
///
/// NOTE(review): presumably the position within the `docs` slice passed
/// to `rerank` — confirm against implementors.
pub index: usize,
/// Score assigned to the document; scale and direction are not defined
/// in this file.
pub score: f32,
/// Optional document text.
///
/// NOTE(review): it is not visible here when implementors populate this.
pub text: Option<String>,
}
/// Interface for models that score documents against a query.
///
/// Implementors must be `Send + Sync`.
#[async_trait]
pub trait RerankerModel: Send + Sync {
/// Scores `docs` against `query` and returns the results as `ScoredDoc`s.
///
/// NOTE(review): ordering of the returned vector (e.g. sorted by score)
/// is not specified by this trait — confirm against implementors.
///
/// # Errors
///
/// Returns the crate's error type on failure; conditions are
/// implementor-defined.
async fn rerank(&self, query: &str, docs: &[&str]) -> Result<Vec<ScoredDoc>>;
/// Optional warm-up hook.
///
/// The default implementation does nothing and returns `Ok(())`.
async fn warmup(&self) -> Result<()> {
Ok(())
}
}
/// Role attached to a `Message`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageRole {
/// Role used by `Message::system`.
System,
/// Role used by `Message::user`.
User,
/// Role used by `Message::assistant`.
Assistant,
}
/// An image supplied as part of a message's content.
#[derive(Debug, Clone)]
pub enum ImageInput {
/// Image given inline as raw bytes (`data`) plus a `media_type` string.
///
/// NOTE(review): `media_type` is presumably a MIME type such as
/// `image/png` — confirm against consumers.
Bytes { data: Vec<u8>, media_type: String },
/// Image referenced by a URL string.
Url(String),
}
/// One piece of a `Message`'s content.
#[derive(Debug, Clone)]
pub enum ContentBlock {
/// A text fragment.
Text(String),
/// An image, described by an `ImageInput`.
Image(ImageInput),
}
/// A message made of a role and an ordered list of content blocks.
#[derive(Debug, Clone)]
pub struct Message {
/// Role attached to this message.
pub role: MessageRole,
/// Content blocks of this message, in order.
pub content: Vec<ContentBlock>,
}
impl Message {
pub fn user(text: impl Into<String>) -> Self {
Self {
role: MessageRole::User,
content: vec![ContentBlock::Text(text.into())],
}
}
pub fn assistant(text: impl Into<String>) -> Self {
Self {
role: MessageRole::Assistant,
content: vec![ContentBlock::Text(text.into())],
}
}
pub fn system(text: impl Into<String>) -> Self {
Self {
role: MessageRole::System,
content: vec![ContentBlock::Text(text.into())],
}
}
pub fn text(&self) -> String {
self.content
.iter()
.filter_map(|b| match b {
ContentBlock::Text(t) => Some(t.as_str()),
_ => None,
})
.collect::<Vec<_>>()
.join(" ")
}
}
/// Optional parameters passed to `GeneratorModel::generate`.
///
/// Every field is an `Option`; the derived `Default` sets all of them to
/// `None`. How an unset field is interpreted is up to each implementor.
#[derive(Debug, Clone, Default)]
pub struct GenerationOptions {
/// NOTE(review): presumably an upper bound on generated tokens — confirm.
pub max_tokens: Option<usize>,
/// NOTE(review): presumably the sampling temperature — confirm.
pub temperature: Option<f32>,
/// NOTE(review): presumably the nucleus-sampling probability mass —
/// confirm.
pub top_p: Option<f32>,
/// NOTE(review): presumably the requested width of generated images —
/// confirm units and which models honor it.
pub width: Option<u32>,
/// NOTE(review): presumably the requested height of generated images —
/// confirm units and which models honor it.
pub height: Option<u32>,
}
/// An image carried in `GenerationResult::images`.
#[derive(Debug, Clone)]
pub struct GeneratedImage {
/// Raw image bytes.
pub data: Vec<u8>,
/// Media type of `data`.
///
/// NOTE(review): presumably a MIME type such as `image/png` — confirm.
pub media_type: String,
}
/// Audio carried in `GenerationResult::audio`.
#[derive(Debug, Clone)]
pub struct AudioOutput {
/// Audio samples as `f32` values.
///
/// NOTE(review): sample range and channel interleaving are not specified
/// in this file — confirm against producers.
pub pcm_data: Vec<f32>,
/// Sample rate. NOTE(review): presumably in Hz — confirm.
pub sample_rate: usize,
/// Number of audio channels.
pub channels: usize,
}
/// Output of `GeneratorModel::generate`.
#[derive(Debug, Clone)]
pub struct GenerationResult {
/// Generated text.
pub text: String,
/// Token accounting for the call, when the implementor provides it.
pub usage: Option<TokenUsage>,
/// Generated images; may be empty.
pub images: Vec<GeneratedImage>,
/// Generated audio, if any.
pub audio: Option<AudioOutput>,
}
/// Token counts attached to a `GenerationResult`.
#[derive(Debug, Clone)]
pub struct TokenUsage {
/// Number of tokens attributed to the prompt.
pub prompt_tokens: usize,
/// Number of tokens attributed to the completion.
pub completion_tokens: usize,
/// Total token count.
///
/// NOTE(review): presumably `prompt_tokens + completion_tokens`; this
/// type does not enforce that relationship.
pub total_tokens: usize,
}
/// Interface for models that produce a `GenerationResult` from a list of
/// messages.
///
/// Implementors must be `Send + Sync`.
#[async_trait]
pub trait GeneratorModel: Send + Sync {
/// Generates a result from `messages` using the given `options`.
///
/// # Errors
///
/// Returns the crate's error type on failure; conditions are
/// implementor-defined.
async fn generate(
&self,
messages: &[Message],
options: GenerationOptions,
) -> Result<GenerationResult>;
/// Optional warm-up hook.
///
/// The default implementation does nothing and returns `Ok(())`.
async fn warmup(&self) -> Result<()> {
Ok(())
}
}