Skip to main content

roder_api/
embeddings.rs

1use std::sync::Arc;
2
3use serde::{Deserialize, Serialize};
4
5use crate::extension::EmbeddingProviderId;
6
7#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
8#[serde(rename_all = "camelCase")]
9pub struct EmbeddingModelDescriptor {
10    pub id: String,
11    pub dimensions: usize,
12    #[serde(default)]
13    pub default: bool,
14}
15
16#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
17#[serde(rename_all = "camelCase")]
18pub struct EmbeddingProviderDescriptor {
19    pub id: EmbeddingProviderId,
20    pub name: String,
21    pub default_model: String,
22    pub models: Vec<EmbeddingModelDescriptor>,
23}
24
25#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, PartialEq, Eq)]
26#[serde(rename_all = "lowercase")]
27pub enum EmbeddingInputType {
28    Query,
29    #[default]
30    Document,
31}
32
33#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
34#[serde(rename_all = "camelCase")]
35pub struct EmbeddingRequest {
36    pub model: String,
37    pub inputs: Vec<String>,
38    #[serde(default)]
39    pub input_type: EmbeddingInputType,
40    #[serde(default, skip_serializing_if = "Option::is_none")]
41    pub dimensions: Option<usize>,
42}
43
44#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
45#[serde(rename_all = "camelCase")]
46pub struct EmbeddingVector {
47    pub index: usize,
48    pub values: Vec<f32>,
49}
50
51#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
52#[serde(rename_all = "camelCase")]
53pub struct EmbeddingResponse {
54    pub provider_id: EmbeddingProviderId,
55    pub model: String,
56    pub embeddings: Vec<EmbeddingVector>,
57}
58
59#[async_trait::async_trait]
60pub trait EmbeddingProvider: Send + Sync + 'static {
61    fn descriptor(&self) -> EmbeddingProviderDescriptor;
62
63    async fn embed(&self, request: EmbeddingRequest) -> anyhow::Result<EmbeddingResponse>;
64}
65
66#[derive(Clone)]
67pub struct EmbeddingProviderFactory {
68    provider: Arc<dyn EmbeddingProvider>,
69}
70
71impl EmbeddingProviderFactory {
72    pub fn new(provider: Arc<dyn EmbeddingProvider>) -> Self {
73        Self { provider }
74    }
75
76    pub fn id(&self) -> EmbeddingProviderId {
77        self.provider.descriptor().id
78    }
79
80    pub fn create(&self) -> Arc<dyn EmbeddingProvider> {
81        self.provider.clone()
82    }
83}
84
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn embedding_request_serializes_query_intent() {
        let request = EmbeddingRequest {
            model: "zembed-1".to_string(),
            inputs: vec!["what changed?".to_string()],
            input_type: EmbeddingInputType::Query,
            dimensions: Some(2560),
        };

        let json = serde_json::to_value(request).unwrap();

        assert_eq!(json["inputType"], "query");
        assert_eq!(json["dimensions"], 2560);
    }

    #[test]
    fn embedding_request_defaults_to_document_intent() {
        let request: EmbeddingRequest = serde_json::from_value(serde_json::json!({
            "model": "zembed-1",
            "inputs": ["stored memory"]
        }))
        .unwrap();

        assert_eq!(request.input_type, EmbeddingInputType::Document);
    }

    #[test]
    fn embedding_request_omits_unset_dimensions() {
        let request = EmbeddingRequest {
            model: "zembed-1".to_string(),
            inputs: vec!["stored memory".to_string()],
            input_type: EmbeddingInputType::default(),
            dimensions: None,
        };

        let json = serde_json::to_value(&request).unwrap();

        assert_eq!(json["inputType"], "document");
        // `skip_serializing_if` must drop the key entirely rather than emit `null`
        // (indexing a missing key would also yield `Null`, so use `get`).
        assert!(json.get("dimensions").is_none());

        // The missing key must read back as `None`.
        let back: EmbeddingRequest = serde_json::from_value(json).unwrap();
        assert_eq!(back, request);
    }

    #[test]
    fn embedding_input_type_uses_lowercase_names() {
        let query: EmbeddingInputType =
            serde_json::from_value(serde_json::json!("query")).unwrap();
        assert_eq!(query, EmbeddingInputType::Query);

        let rejected = serde_json::from_value::<EmbeddingInputType>(serde_json::json!("Query"));
        assert!(rejected.is_err());
    }

    #[test]
    fn model_descriptor_default_flag_is_optional() {
        let model: EmbeddingModelDescriptor = serde_json::from_value(serde_json::json!({
            "id": "zembed-1",
            "dimensions": 2560
        }))
        .unwrap();

        assert!(!model.default);
        assert_eq!(model.dimensions, 2560);
    }

    #[test]
    fn embedding_vector_round_trips() {
        let vector = EmbeddingVector {
            index: 3,
            values: vec![0.5, -1.0],
        };

        let json = serde_json::to_value(&vector).unwrap();
        assert_eq!(json["index"], 3);

        let back: EmbeddingVector = serde_json::from_value(json).unwrap();
        assert_eq!(back, vector);
    }
}
114}