use serde::{Deserialize, Serialize};
use std::path::PathBuf;
/// Kind of content a capability works with.
///
/// Used to describe what a provider accepts (e.g. `Embedding::modalities`,
/// `CapabilityMetadata::modalities`) and to classify a single `EmbedInput`
/// (`EmbedInput::modality`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum Modality {
    /// Plain text (`EmbedInput::Text`).
    Text,
    /// Still images, either raw bytes or a file path (`EmbedInput::ImageBytes` /
    /// `EmbedInput::ImagePath`).
    Image,
    /// Video content (`EmbedInput::VideoFrame`).
    Video,
    /// Audio content. No `EmbedInput` variant in this module maps to it.
    Audio,
    /// Composite/structured content (`EmbedInput::Mixed`).
    Structured,
}
/// Error type returned by every capability trait in this module.
///
/// `kind` gives the broad category, `message` the human-readable detail, and
/// `retryable` is a hint that repeating the same request may succeed.
#[derive(Debug, Clone)]
pub struct CapabilityError {
    /// Broad category of the failure.
    pub kind: CapabilityErrorKind,
    /// Human-readable description of what went wrong.
    pub message: String,
    /// Whether the caller may reasonably retry the same request.
    pub retryable: bool,
}

impl std::fmt::Display for CapabilityError {
    /// Renders as `<Kind>: <message>`, using the kind's `Debug` name,
    /// e.g. `Network: connection refused`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { kind, message, .. } = self;
        write!(f, "{kind:?}: {message}")
    }
}

impl std::error::Error for CapabilityError {}

/// Category of a [`CapabilityError`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CapabilityErrorKind {
    /// Authentication failed.
    Authentication,
    /// A rate limit was hit.
    RateLimit,
    /// The request input was invalid.
    InvalidInput,
    /// The provider does not accept the requested [`Modality`].
    UnsupportedModality,
    /// A network/transport failure.
    Network,
    /// An error reported by the upstream provider.
    ProviderError,
    /// An error from a backing store.
    StoreError,
    /// The requested item does not exist.
    NotFound,
    /// The operation timed out.
    Timeout,
}
impl CapabilityError {
#[must_use]
pub fn invalid_input(message: impl Into<String>) -> Self {
Self {
kind: CapabilityErrorKind::InvalidInput,
message: message.into(),
retryable: false,
}
}
#[must_use]
pub fn unsupported_modality(modality: Modality) -> Self {
Self {
kind: CapabilityErrorKind::UnsupportedModality,
message: format!("Modality {modality:?} is not supported by this provider"),
retryable: false,
}
}
#[must_use]
pub fn store(message: impl Into<String>) -> Self {
Self {
kind: CapabilityErrorKind::StoreError,
message: message.into(),
retryable: false,
}
}
#[must_use]
pub fn network(message: impl Into<String>) -> Self {
Self {
kind: CapabilityErrorKind::Network,
message: message.into(),
retryable: true,
}
}
#[must_use]
pub fn auth(message: impl Into<String>) -> Self {
Self {
kind: CapabilityErrorKind::Authentication,
message: message.into(),
retryable: false,
}
}
#[must_use]
pub fn not_found(message: impl Into<String>) -> Self {
Self {
kind: CapabilityErrorKind::NotFound,
message: message.into(),
retryable: false,
}
}
}
/// One item of content to embed. It is also the type of the query and the
/// candidates in a [`RerankRequest`].
#[derive(Debug, Clone)]
pub enum EmbedInput {
    /// Plain text.
    Text(String),
    /// In-memory image data together with its MIME type (e.g. `"image/png"`).
    ImageBytes {
        data: Vec<u8>,
        mime_type: String,
    },
    /// An image referenced by filesystem path.
    ImagePath(PathBuf),
    /// A single frame of the video at `path`, `timestamp_ms` milliseconds in.
    VideoFrame {
        path: PathBuf,
        timestamp_ms: u64,
    },
    /// Several inputs combined into one. Classified as [`Modality::Structured`]
    /// regardless of its parts.
    Mixed(Vec<EmbedInput>),
}

impl EmbedInput {
    /// Text input.
    #[must_use]
    pub fn text(s: impl Into<String>) -> Self {
        Self::Text(s.into())
    }

    /// Image input referenced by filesystem path.
    #[must_use]
    pub fn image_path(path: impl Into<PathBuf>) -> Self {
        Self::ImagePath(path.into())
    }

    /// Image input from in-memory bytes with the given MIME type.
    #[must_use]
    pub fn image_bytes(data: Vec<u8>, mime_type: impl Into<String>) -> Self {
        Self::ImageBytes {
            data,
            mime_type: mime_type.into(),
        }
    }

    /// A single video frame: the frame at `timestamp_ms` milliseconds into the
    /// video at `path`.
    #[must_use]
    pub fn video_frame(path: impl Into<PathBuf>, timestamp_ms: u64) -> Self {
        Self::VideoFrame {
            path: path.into(),
            timestamp_ms,
        }
    }

    /// A composite input built from any collection of inputs.
    #[must_use]
    pub fn mixed(inputs: impl IntoIterator<Item = Self>) -> Self {
        Self::Mixed(inputs.into_iter().collect())
    }

    /// Classifies this input by [`Modality`].
    ///
    /// Both image variants map to `Image`. `Mixed` maps to `Structured` without
    /// inspecting its parts.
    #[must_use]
    pub fn modality(&self) -> Modality {
        match self {
            Self::Text(_) => Modality::Text,
            Self::ImageBytes { .. } | Self::ImagePath(_) => Modality::Image,
            Self::VideoFrame { .. } => Modality::Video,
            Self::Mixed(_) => Modality::Structured,
        }
    }
}
/// A batch embedding request plus the options that shape the returned vectors.
///
/// Build it with [`EmbedRequest::new`] or [`EmbedRequest::text`], then refine it
/// with the `with_*` methods. Normalization is enabled by default.
#[derive(Debug, Clone)]
pub struct EmbedRequest {
    /// Inputs to embed.
    pub inputs: Vec<EmbedInput>,
    /// Requested output dimensionality. `None` presumably defers to the
    /// provider's `Embedding::default_dimensions` — confirm per provider.
    pub dimensions: Option<usize>,
    /// Optional task instruction for the provider (e.g. `"retrieval"`); how it is
    /// applied is provider-specific.
    pub task_instruction: Option<String>,
    /// Whether the vectors should be normalized; `true` unless overridden.
    pub normalize: bool,
}

impl EmbedRequest {
    /// Creates a request for `inputs` with default options: no explicit
    /// dimensions, no task instruction, normalization on.
    #[must_use]
    pub fn new(inputs: Vec<EmbedInput>) -> Self {
        Self {
            normalize: true,
            task_instruction: None,
            dimensions: None,
            inputs,
        }
    }

    /// Shorthand for a request holding exactly one text input.
    #[must_use]
    pub fn text(s: impl Into<String>) -> Self {
        let only = EmbedInput::Text(s.into());
        Self::new(vec![only])
    }

    /// Requests output vectors of length `dim`.
    #[must_use]
    pub fn with_dimensions(self, dim: usize) -> Self {
        Self {
            dimensions: Some(dim),
            ..self
        }
    }

    /// Attaches a task instruction for the provider.
    #[must_use]
    pub fn with_task(self, instruction: impl Into<String>) -> Self {
        Self {
            task_instruction: Some(instruction.into()),
            ..self
        }
    }

    /// Overrides the normalization flag.
    #[must_use]
    pub fn with_normalize(self, normalize: bool) -> Self {
        Self { normalize, ..self }
    }
}
/// Result of an `Embedding::embed` call.
#[derive(Debug, Clone)]
pub struct EmbedResponse {
    /// The produced vectors — presumably one per request input, in request
    /// order; confirm against providers.
    pub embeddings: Vec<Vec<f32>>,
    /// Identifier of the model that produced the vectors.
    pub model: String,
    /// Dimensionality of the returned vectors (presumably the length of each
    /// entry in `embeddings`).
    pub dimensions: usize,
    /// Token usage, when the provider reports it.
    pub usage: Option<EmbedUsage>,
}
/// Token accounting attached to an [`EmbedResponse`].
#[derive(Debug, Clone, Default)]
pub struct EmbedUsage {
    /// Total tokens consumed by the request.
    pub total_tokens: u32,
}
/// A provider that turns [`EmbedInput`]s into vectors.
pub trait Embedding: Send + Sync {
    /// Provider name.
    fn name(&self) -> &str;

    /// Modalities this provider can embed.
    fn modalities(&self) -> Vec<Modality>;

    /// Vector length the provider produces by default — presumably used when
    /// `EmbedRequest::dimensions` is `None`; confirm per provider.
    fn default_dimensions(&self) -> usize;

    /// Embeds the inputs of `request`.
    ///
    /// # Errors
    /// Returns a [`CapabilityError`] when embedding fails.
    fn embed(&self, request: &EmbedRequest) -> Result<EmbedResponse, CapabilityError>;

    /// Returns `true` if `modality` is listed by [`Embedding::modalities`].
    ///
    /// The default calls `modalities()` (building a fresh `Vec`) on every call;
    /// implementors may override it with a cheaper check.
    fn supports(&self, modality: Modality) -> bool {
        self.modalities()
            .iter()
            .any(|&candidate| candidate == modality)
    }
}
/// Request to score and order `candidates` by relevance to `query`.
#[derive(Debug, Clone)]
pub struct RerankRequest {
    /// The query candidates are scored against.
    pub query: EmbedInput,
    /// Items to be scored.
    pub candidates: Vec<EmbedInput>,
    /// Optional cap on the number of results — presumably enforced by the
    /// provider; confirm.
    pub top_k: Option<usize>,
    /// Optional minimum score a candidate must reach to be kept — presumably
    /// enforced by the provider; confirm.
    pub min_score: Option<f64>,
}

impl RerankRequest {
    /// Creates a request with no `top_k` cap and no score threshold.
    #[must_use]
    pub fn new(query: EmbedInput, candidates: Vec<EmbedInput>) -> Self {
        Self {
            top_k: None,
            min_score: None,
            query,
            candidates,
        }
    }

    /// Shorthand for a text query reranking text candidates.
    #[must_use]
    pub fn text(query: impl Into<String>, candidates: Vec<String>) -> Self {
        let query = EmbedInput::Text(query.into());
        let candidates = candidates.into_iter().map(EmbedInput::Text).collect();
        Self::new(query, candidates)
    }

    /// Limits the result to at most `k` items.
    #[must_use]
    pub fn with_top_k(self, k: usize) -> Self {
        Self {
            top_k: Some(k),
            ..self
        }
    }

    /// Sets the minimum score threshold.
    #[must_use]
    pub fn with_min_score(self, score: f64) -> Self {
        Self {
            min_score: Some(score),
            ..self
        }
    }
}
/// One scored entry in a [`RerankResponse`].
#[derive(Debug, Clone)]
pub struct RankedItem {
    /// Position of the candidate — presumably an index into the originating
    /// `RerankRequest::candidates`; confirm against implementations.
    pub index: usize,
    /// Relevance score assigned by the reranker; its scale is provider-specific.
    pub score: f64,
}
/// Output of `Reranking::rerank`.
#[derive(Debug, Clone)]
pub struct RerankResponse {
    /// Scored candidates — presumably ordered best-first and already limited by
    /// the request's `top_k`/`min_score`; verify against implementations.
    pub ranked: Vec<RankedItem>,
    /// Identifier of the reranking model.
    pub model: String,
}
/// A provider that scores candidate inputs against a query.
pub trait Reranking: Send + Sync {
    /// Provider name.
    fn name(&self) -> &str;
    /// Modalities this reranker accepts.
    fn modalities(&self) -> Vec<Modality>;
    /// Scores `request.candidates` against `request.query`.
    ///
    /// # Errors
    /// Returns a [`CapabilityError`] when reranking fails.
    fn rerank(&self, request: &RerankRequest) -> Result<RerankResponse, CapabilityError>;
}
/// An entry written to a `VectorRecall` store.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorRecord {
    /// Record identifier — presumably the same key accepted by
    /// `VectorRecall::delete`; confirm per backend.
    pub id: String,
    /// The stored vector.
    pub vector: Vec<f32>,
    /// Arbitrary JSON metadata — presumably echoed back in `VectorMatch::payload`.
    pub payload: serde_json::Value,
}
/// Similarity query against a `VectorRecall` store.
#[derive(Debug, Clone)]
pub struct VectorQuery {
    /// The query vector.
    pub vector: Vec<f32>,
    /// Maximum number of matches requested.
    pub top_k: usize,
    /// Optional backend-specific JSON filter — presumably applied to record
    /// payloads; confirm per backend.
    pub filter: Option<serde_json::Value>,
    /// Optional minimum score for a match to be returned — presumably; confirm
    /// per backend.
    pub min_score: Option<f64>,
}

impl VectorQuery {
    /// Creates a query for up to `top_k` matches of `vector`, with no filter and
    /// no score threshold.
    #[must_use]
    pub fn new(vector: Vec<f32>, top_k: usize) -> Self {
        let (filter, min_score) = (None, None);
        Self {
            vector,
            top_k,
            filter,
            min_score,
        }
    }

    /// Attaches a JSON filter.
    #[must_use]
    pub fn with_filter(self, filter: serde_json::Value) -> Self {
        Self {
            filter: Some(filter),
            ..self
        }
    }

    /// Sets the minimum score threshold.
    #[must_use]
    pub fn with_min_score(self, score: f64) -> Self {
        Self {
            min_score: Some(score),
            ..self
        }
    }
}
/// One hit returned by `VectorRecall::query`.
#[derive(Debug, Clone)]
pub struct VectorMatch {
    /// Identifier of the matched record.
    pub id: String,
    /// Similarity score; its metric and scale are backend-specific.
    pub score: f64,
    /// JSON payload of the matched record — presumably the stored
    /// `VectorRecord::payload`.
    pub payload: serde_json::Value,
}
/// A vector store supporting upsert, similarity query and deletion.
pub trait VectorRecall: Send + Sync {
    /// Store name.
    fn name(&self) -> &str;

    /// Inserts `record`, or replaces an existing one — presumably matched by
    /// `id`; confirm per backend.
    ///
    /// # Errors
    /// Returns a [`CapabilityError`] if the write fails.
    fn upsert(&self, record: &VectorRecord) -> Result<(), CapabilityError>;

    /// Upserts `records` one by one, in order, stopping at the first error.
    ///
    /// The default performs no rollback of records already upserted before the
    /// failure. Backends with a native bulk write may override it.
    ///
    /// # Errors
    /// Returns the first error produced by [`VectorRecall::upsert`].
    fn upsert_batch(&self, records: &[VectorRecord]) -> Result<(), CapabilityError> {
        records.iter().try_for_each(|record| self.upsert(record))
    }

    /// Runs a similarity query.
    ///
    /// # Errors
    /// Returns a [`CapabilityError`] if the query fails.
    fn query(&self, query: &VectorQuery) -> Result<Vec<VectorMatch>, CapabilityError>;

    /// Deletes the record with `id`.
    ///
    /// # Errors
    /// Returns a [`CapabilityError`] if the deletion fails.
    fn delete(&self, id: &str) -> Result<(), CapabilityError>;

    /// Removes every record.
    ///
    /// # Errors
    /// Returns a [`CapabilityError`] if the store cannot be cleared.
    fn clear(&self) -> Result<(), CapabilityError>;

    /// Number of records currently stored.
    ///
    /// # Errors
    /// Returns a [`CapabilityError`] if the count cannot be obtained.
    fn count(&self) -> Result<usize, CapabilityError>;
}
/// A node stored in a `GraphRecall` backend.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphNode {
    /// Node identifier (as used by `GraphRecall::get_node` / `delete_node`).
    pub id: String,
    /// Node label (matched by `GraphRecall::find_nodes`).
    pub label: String,
    /// Arbitrary JSON properties.
    pub properties: serde_json::Value,
}
/// A relationship between two nodes, identified by their ids.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphEdge {
    /// Id of the source node — presumably the edge is directed `from -> to`.
    pub from: String,
    /// Id of the target node.
    pub to: String,
    /// Relationship type (matched by `GraphQuery::relationships`).
    pub relationship: String,
    /// Optional JSON properties of the edge.
    pub properties: Option<serde_json::Value>,
}
/// Traversal request for a `GraphRecall` store.
///
/// Built with [`GraphQuery::from_node`], which starts from a single node with a
/// depth of 2 and a limit of 100, then refined with the `with_*` methods.
#[derive(Debug, Clone)]
pub struct GraphQuery {
    /// Ids of the nodes the traversal starts from.
    pub start_nodes: Vec<String>,
    /// Relationship types to follow. Empty presumably means "any relationship" —
    /// confirm against implementations.
    pub relationships: Vec<String>,
    /// Maximum traversal depth.
    pub max_depth: usize,
    /// Cap on the size of the result — presumably on returned nodes; confirm.
    pub limit: usize,
}

impl GraphQuery {
    /// Depth used by [`GraphQuery::from_node`].
    const DEFAULT_MAX_DEPTH: usize = 2;
    /// Result limit used by [`GraphQuery::from_node`].
    const DEFAULT_LIMIT: usize = 100;

    /// Starts a traversal from the single node `id`, following any
    /// relationship, with the default depth and limit.
    #[must_use]
    pub fn from_node(id: impl Into<String>) -> Self {
        let start: String = id.into();
        Self {
            start_nodes: vec![start],
            relationships: vec![],
            max_depth: Self::DEFAULT_MAX_DEPTH,
            limit: Self::DEFAULT_LIMIT,
        }
    }

    /// Replaces the set of relationship types to follow.
    #[must_use]
    pub fn with_relationships(self, rels: Vec<String>) -> Self {
        Self {
            relationships: rels,
            ..self
        }
    }

    /// Sets the maximum traversal depth.
    #[must_use]
    pub fn with_max_depth(self, depth: usize) -> Self {
        Self {
            max_depth: depth,
            ..self
        }
    }

    /// Sets the result limit.
    #[must_use]
    pub fn with_limit(self, limit: usize) -> Self {
        Self { limit, ..self }
    }
}
/// Subgraph returned by `GraphRecall::traverse`.
#[derive(Debug, Clone)]
pub struct GraphResult {
    /// Nodes reached by the traversal.
    pub nodes: Vec<GraphNode>,
    /// Edges reached by the traversal.
    pub edges: Vec<GraphEdge>,
}
/// A graph store of labelled nodes and typed edges.
pub trait GraphRecall: Send + Sync {
    /// Store name.
    fn name(&self) -> &str;
    /// Adds `node`. Whether an existing id is replaced is backend-defined — confirm.
    fn add_node(&self, node: &GraphNode) -> Result<(), CapabilityError>;
    /// Adds `edge` between `edge.from` and `edge.to`.
    fn add_edge(&self, edge: &GraphEdge) -> Result<(), CapabilityError>;
    /// Walks the graph from `query.start_nodes`, bounded by the query's depth and limit.
    fn traverse(&self, query: &GraphQuery) -> Result<GraphResult, CapabilityError>;
    /// Finds nodes with `label`, optionally also matching `properties`. The
    /// property-matching semantics are backend-defined — TODO confirm.
    fn find_nodes(
        &self,
        label: &str,
        properties: Option<&serde_json::Value>,
    ) -> Result<Vec<GraphNode>, CapabilityError>;
    /// Looks up a node by id; `Ok(None)` presumably means no such node exists.
    fn get_node(&self, id: &str) -> Result<Option<GraphNode>, CapabilityError>;
    /// Deletes the node with `id`. Whether incident edges are removed too is
    /// backend-defined — confirm.
    fn delete_node(&self, id: &str) -> Result<(), CapabilityError>;
    /// Removes all graph contents.
    fn clear(&self) -> Result<(), CapabilityError>;
}
/// Descriptive metadata about a provider: what it offers and how it runs.
#[derive(Debug, Clone)]
pub struct CapabilityMetadata {
    /// Provider name.
    pub provider: String,
    /// Capabilities the provider offers.
    pub capabilities: Vec<CapabilityKind>,
    /// Modalities the provider accepts.
    pub modalities: Vec<Modality>,
    /// Presumably `true` when the provider runs locally rather than as a remote
    /// service — confirm.
    pub is_local: bool,
    /// Typical latency in milliseconds (presumably advisory, not a guarantee).
    pub typical_latency_ms: u32,
}
/// Category of capability a provider can offer.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum CapabilityKind {
    /// Completion; no trait for it is defined in this module.
    Completion,
    /// Embedding; see the `Embedding` trait.
    Embedding,
    /// Reranking; see the `Reranking` trait.
    Reranking,
    /// Vector recall; see the `VectorRecall` trait.
    VectorRecall,
    /// Graph recall; see the `GraphRecall` trait.
    GraphRecall,
    /// Document recall; no trait for it is defined in this module.
    DocRecall,
    /// Vision; no trait for it is defined in this module.
    Vision,
    /// Audio; no trait for it is defined in this module.
    Audio,
    /// Code execution; no trait for it is defined in this module.
    CodeExecution,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Embedding stub that only declares text support, used to exercise the
    /// trait's default `supports` implementation.
    struct TextOnly;

    impl Embedding for TextOnly {
        fn name(&self) -> &str {
            "text-only"
        }
        fn modalities(&self) -> Vec<Modality> {
            vec![Modality::Text]
        }
        fn default_dimensions(&self) -> usize {
            8
        }
        fn embed(&self, _request: &EmbedRequest) -> Result<EmbedResponse, CapabilityError> {
            Err(CapabilityError::unsupported_modality(Modality::Text))
        }
    }

    /// Vector store stub that records accepted ids and rejects the id `"bad"`,
    /// used to exercise the default `upsert_batch`.
    struct RecordingStore {
        seen: std::sync::Mutex<Vec<String>>,
    }

    impl VectorRecall for RecordingStore {
        fn name(&self) -> &str {
            "recording"
        }
        fn upsert(&self, record: &VectorRecord) -> Result<(), CapabilityError> {
            if record.id == "bad" {
                return Err(CapabilityError::store("rejected"));
            }
            self.seen
                .lock()
                .expect("test mutex is never poisoned")
                .push(record.id.clone());
            Ok(())
        }
        fn query(&self, _query: &VectorQuery) -> Result<Vec<VectorMatch>, CapabilityError> {
            Ok(Vec::new())
        }
        fn delete(&self, _id: &str) -> Result<(), CapabilityError> {
            Ok(())
        }
        fn clear(&self) -> Result<(), CapabilityError> {
            Ok(())
        }
        fn count(&self) -> Result<usize, CapabilityError> {
            Ok(self.seen.lock().expect("test mutex is never poisoned").len())
        }
    }

    #[test]
    fn embed_input_modality() {
        assert_eq!(EmbedInput::text("hello").modality(), Modality::Text);
        assert_eq!(
            EmbedInput::image_path("/foo.png").modality(),
            Modality::Image
        );
        assert_eq!(
            EmbedInput::image_bytes(vec![0x89, 0x50], "image/png").modality(),
            Modality::Image
        );
        assert_eq!(
            EmbedInput::VideoFrame {
                path: std::path::PathBuf::from("/clip.mp4"),
                timestamp_ms: 1_500,
            }
            .modality(),
            Modality::Video
        );
        // Composite inputs are classified as a whole, not by their parts.
        assert_eq!(
            EmbedInput::Mixed(vec![EmbedInput::text("caption")]).modality(),
            Modality::Structured
        );
    }

    #[test]
    fn embed_request_builder() {
        let req = EmbedRequest::text("test")
            .with_dimensions(512)
            .with_task("retrieval")
            .with_normalize(false);
        assert_eq!(req.inputs.len(), 1);
        assert_eq!(req.dimensions, Some(512));
        assert_eq!(req.task_instruction, Some("retrieval".into()));
        assert!(!req.normalize);
    }

    #[test]
    fn rerank_request_builder() {
        let req = RerankRequest::text("query", vec!["a".into(), "b".into()])
            .with_top_k(5)
            .with_min_score(0.5);
        assert_eq!(req.candidates.len(), 2);
        assert_eq!(req.top_k, Some(5));
        assert_eq!(req.min_score, Some(0.5));
    }

    #[test]
    fn vector_query_builder() {
        let query = VectorQuery::new(vec![0.1, 0.2, 0.3], 10)
            .with_min_score(0.8)
            .with_filter(serde_json::json!({"type": "document"}));
        assert_eq!(query.top_k, 10);
        assert_eq!(query.min_score, Some(0.8));
        assert!(query.filter.is_some());
    }

    #[test]
    fn graph_query_builder() {
        let query = GraphQuery::from_node("company-1")
            .with_relationships(vec!["COMPETES_WITH".into()])
            .with_max_depth(3)
            .with_limit(50);
        assert_eq!(query.start_nodes, vec!["company-1"]);
        assert_eq!(query.max_depth, 3);
        assert_eq!(query.limit, 50);
    }

    #[test]
    fn builders_start_from_defaults() {
        let embed = EmbedRequest::text("x");
        assert_eq!(embed.dimensions, None);
        assert_eq!(embed.task_instruction, None);
        assert!(embed.normalize);

        let rerank = RerankRequest::text("q", Vec::new());
        assert!(rerank.candidates.is_empty());
        assert_eq!(rerank.top_k, None);
        assert_eq!(rerank.min_score, None);

        let vector = VectorQuery::new(vec![1.0], 3);
        assert!(vector.filter.is_none());
        assert_eq!(vector.min_score, None);

        let graph = GraphQuery::from_node("n");
        assert!(graph.relationships.is_empty());
        assert_eq!(graph.max_depth, 2);
        assert_eq!(graph.limit, 100);
    }

    #[test]
    fn capability_error_creation() {
        let err = CapabilityError::unsupported_modality(Modality::Video);
        assert_eq!(err.kind, CapabilityErrorKind::UnsupportedModality);
        assert!(!err.retryable);
        let err = CapabilityError::network("connection refused");
        assert_eq!(err.kind, CapabilityErrorKind::Network);
        assert!(err.retryable);
    }

    #[test]
    fn capability_error_messages_and_flags() {
        let err = CapabilityError::invalid_input("bad payload");
        assert_eq!(err.to_string(), "InvalidInput: bad payload");
        assert!(!err.retryable);

        let err = CapabilityError::unsupported_modality(Modality::Audio);
        assert_eq!(
            err.message,
            "Modality Audio is not supported by this provider"
        );

        assert_eq!(
            CapabilityError::store("s").kind,
            CapabilityErrorKind::StoreError
        );
        assert_eq!(
            CapabilityError::auth("a").kind,
            CapabilityErrorKind::Authentication
        );
        assert_eq!(
            CapabilityError::not_found("n").kind,
            CapabilityErrorKind::NotFound
        );
        for err in [
            CapabilityError::store("s"),
            CapabilityError::auth("a"),
            CapabilityError::not_found("n"),
        ] {
            assert!(!err.retryable, "{err} should not be retryable");
        }
    }

    #[test]
    fn embedding_supports_uses_declared_modalities() {
        let provider = TextOnly;
        assert!(provider.supports(Modality::Text));
        assert!(!provider.supports(Modality::Image));
        assert!(!provider.supports(Modality::Structured));
    }

    #[test]
    fn upsert_batch_stops_at_first_error() {
        let store = RecordingStore {
            seen: std::sync::Mutex::new(Vec::new()),
        };
        let record = |id: &str| VectorRecord {
            id: id.to_string(),
            vector: vec![0.0],
            payload: serde_json::Value::Null,
        };
        let batch = [record("a"), record("bad"), record("c")];

        let err = store.upsert_batch(&batch).unwrap_err();
        assert_eq!(err.kind, CapabilityErrorKind::StoreError);
        // Only the record before the failure was written; "c" was never attempted.
        assert_eq!(store.count().unwrap(), 1);
        assert!(store.upsert_batch(&[]).is_ok());
    }
}