use async_trait::async_trait;
use inferd_proto::embed::{EmbedResolved, EmbedUsage};
use inferd_proto::v2::{ResolvedV2, StopReasonV2, ToolCallId, ToolUseInput, UsageV2};
use inferd_proto::{Resolved, StopReason, Usage};
use std::pin::Pin;
use tokio_stream::Stream;
/// One item of the stream produced by [`Backend::generate`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TokenEvent {
/// A piece of generated output text (presumably one decoded token/fragment — confirm).
Token(String),
/// Completion event carrying the stop reason and the usage figures.
/// NOTE(review): presumably always the last item of the stream — confirm with implementors.
Done {
stop_reason: StopReason,
usage: Usage,
},
}
/// Pinned, boxed, `Send` stream of [`TokenEvent`]s; the success value of
/// [`Backend::generate`].
pub type TokenStream = Pin<Box<dyn Stream<Item = TokenEvent> + Send>>;
/// One item of the stream produced by [`Backend::generate_v2`].
///
/// Unlike [`TokenEvent`], this derives only `PartialEq` (not `Eq`).
#[derive(Debug, Clone, PartialEq)]
pub enum TokenEventV2 {
/// Generated output text.
Text(String),
/// Text emitted as model "thinking" output, kept separate from [`TokenEventV2::Text`].
Thinking(String),
/// A tool invocation requested by the model.
ToolUse {
/// Identifier of this tool call.
tool_call_id: ToolCallId,
/// Name of the tool to invoke.
name: String,
/// Input for the tool call.
input: ToolUseInput,
},
/// Completion event carrying the v2 stop reason and usage figures.
/// NOTE(review): presumably always the last item of the stream — confirm with implementors.
Done {
stop_reason: StopReasonV2,
usage: UsageV2,
},
}
/// Pinned, boxed, `Send` stream of [`TokenEventV2`]s; the success value of
/// [`Backend::generate_v2`].
pub type TokenStreamV2 = Pin<Box<dyn Stream<Item = TokenEventV2> + Send>>;
/// Kind of compute accelerator reported in an [`AcceleratorInfo`].
///
/// The [`Default`] value is [`AcceleratorKind::Cpu`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum AcceleratorKind {
    #[default]
    Cpu,
    Cuda,
    Metal,
    Vulkan,
    Rocm,
}

impl AcceleratorKind {
    /// Returns the lowercase name of this accelerator kind:
    /// `"cpu"`, `"cuda"`, `"metal"`, `"vulkan"` or `"rocm"`.
    ///
    /// This is a `const fn`, so the name is also available in constant
    /// contexts (e.g. `const NAME: &str = AcceleratorKind::Cuda.as_str();`).
    pub const fn as_str(self) -> &'static str {
        match self {
            Self::Cpu => "cpu",
            Self::Cuda => "cuda",
            Self::Metal => "metal",
            Self::Vulkan => "vulkan",
            Self::Rocm => "rocm",
        }
    }
}
/// Accelerator details reported as part of [`BackendCapabilities`].
///
/// The derived [`Default`] is `kind: AcceleratorKind::Cpu`, `gpu_layers: 0`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct AcceleratorInfo {
/// Which accelerator the backend uses.
pub kind: AcceleratorKind,
/// Number of GPU layers (presumably layers offloaded to the GPU — confirm with backends).
pub gpu_layers: u32,
}
/// Feature flags and accelerator info returned by [`Backend::capabilities`].
///
/// The derived [`Default`] sets every flag to `false` and uses
/// `AcceleratorInfo::default()`. The default `Backend::capabilities` returns this value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct BackendCapabilities {
/// Presumably: backend implements [`Backend::generate_v2`] — confirm.
pub v2: bool,
/// Presumably: image input is accepted — confirm.
pub vision: bool,
/// Presumably: audio input is accepted — confirm.
pub audio: bool,
/// Presumably: video input is accepted — confirm.
pub video: bool,
/// Presumably: tool calls ([`TokenEventV2::ToolUse`]) can be produced — confirm.
pub tools: bool,
/// Presumably: thinking output ([`TokenEventV2::Thinking`]) can be produced — confirm.
pub thinking: bool,
/// Presumably: backend implements [`Backend::embed`] — confirm.
pub embed: bool,
/// Accelerator the backend runs on.
pub accelerator: AcceleratorInfo,
}
/// Successful result of [`Backend::embed`].
#[derive(Debug, Clone, PartialEq)]
pub struct EmbedResult {
/// The embedding vectors.
/// NOTE(review): presumably one vector per input, each of length `dimensions` — confirm.
pub embeddings: Vec<Vec<f32>>,
/// Dimensionality of the embedding vectors.
pub dimensions: u32,
/// Name of the model that produced the embeddings.
pub model: String,
/// Usage figures for the embed request.
pub usage: EmbedUsage,
}
/// Errors returned by [`Backend::embed`].
///
/// Display strings come from the `#[error(...)]` attributes; variants with a
/// `String` payload append it after the fixed prefix.
#[derive(Debug, thiserror::Error)]
pub enum EmbedError {
/// The backend is not ready.
#[error("backend not ready")]
NotReady,
/// The backend does not support embeddings; this is what the default
/// `Backend::embed` implementation returns.
#[error("embed not supported by this backend")]
Unsupported,
/// The request was invalid; the payload describes why.
#[error("invalid request: {0}")]
InvalidRequest(String),
/// The backend is unavailable; the payload gives details.
#[error("backend unavailable: {0}")]
Unavailable(String),
/// Any other internal failure; the payload gives details.
#[error("internal: {0}")]
Internal(String),
}
/// Errors returned by [`Backend::generate`], [`Backend::generate_v2`] and
/// [`Backend::stop`].
///
/// Display strings come from the `#[error(...)]` attributes; variants with a
/// `String` payload append it after the fixed prefix.
///
/// NOTE(review): unlike `EmbedError` there is no `Unsupported` variant, so the
/// default `Backend::generate_v2` reports "not supported" through `Internal`.
/// Adding a variant would break exhaustive matches in callers, so it is left as is.
#[derive(Debug, thiserror::Error)]
pub enum GenerateError {
/// The backend is not ready.
#[error("backend not ready")]
NotReady,
/// The request was invalid; the payload describes why.
#[error("invalid request: {0}")]
InvalidRequest(String),
/// The backend is unavailable; the payload gives details.
#[error("backend unavailable: {0}")]
Unavailable(String),
/// Any other internal failure; the payload gives details.
#[error("internal: {0}")]
Internal(String),
}
/// An inference backend: text generation (v1 and v2), optional embeddings,
/// and shutdown.
///
/// Only `name`, `ready` and `generate` must be implemented. The other methods
/// have defaults that advertise no extra capabilities, reject v2 generation
/// and embeddings, and treat `stop` as a no-op.
#[async_trait]
pub trait Backend: Send + Sync {
    /// Name of this backend.
    fn name(&self) -> &str;

    /// Returns `true` when the backend reports itself ready.
    fn ready(&self) -> bool;

    /// Capabilities advertised by this backend.
    ///
    /// Default: `BackendCapabilities::default()`, i.e. every flag `false` and
    /// the default accelerator (CPU, zero GPU layers).
    fn capabilities(&self) -> BackendCapabilities {
        Default::default()
    }

    /// Starts a v1 generation and returns its event stream.
    async fn generate(&self, req: Resolved) -> Result<TokenStream, GenerateError>;

    /// Starts a v2 generation and returns its event stream.
    ///
    /// Default: always fails with `GenerateError::Internal` carrying the
    /// message "v2 not supported by this backend".
    async fn generate_v2(&self, _request: ResolvedV2) -> Result<TokenStreamV2, GenerateError> {
        let reason = String::from("v2 not supported by this backend");
        Err(GenerateError::Internal(reason))
    }

    /// Computes embeddings for a resolved embed request.
    ///
    /// Default: always fails with `EmbedError::Unsupported`.
    async fn embed(&self, _request: EmbedResolved) -> Result<EmbedResult, EmbedError> {
        Err(EmbedError::Unsupported)
    }

    /// Stops the backend, given a timeout.
    ///
    /// Default: does nothing and returns `Ok(())` immediately.
    async fn stop(&self, _timeout: std::time::Duration) -> Result<(), GenerateError> {
        Ok(())
    }
}