1use std::sync::Arc;
2
3use serde::{Deserialize, Serialize};
4
5use crate::extension::EmbeddingProviderId;
6
7#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
8#[serde(rename_all = "camelCase")]
9pub struct EmbeddingModelDescriptor {
10 pub id: String,
11 pub dimensions: usize,
12 #[serde(default)]
13 pub default: bool,
14}
15
16#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
17#[serde(rename_all = "camelCase")]
18pub struct EmbeddingProviderDescriptor {
19 pub id: EmbeddingProviderId,
20 pub name: String,
21 pub default_model: String,
22 pub models: Vec<EmbeddingModelDescriptor>,
23}
24
25#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, PartialEq, Eq)]
26#[serde(rename_all = "lowercase")]
27pub enum EmbeddingInputType {
28 Query,
29 #[default]
30 Document,
31}
32
33#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
34#[serde(rename_all = "camelCase")]
35pub struct EmbeddingRequest {
36 pub model: String,
37 pub inputs: Vec<String>,
38 #[serde(default)]
39 pub input_type: EmbeddingInputType,
40 #[serde(default, skip_serializing_if = "Option::is_none")]
41 pub dimensions: Option<usize>,
42}
43
44#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
45#[serde(rename_all = "camelCase")]
46pub struct EmbeddingVector {
47 pub index: usize,
48 pub values: Vec<f32>,
49}
50
51#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
52#[serde(rename_all = "camelCase")]
53pub struct EmbeddingResponse {
54 pub provider_id: EmbeddingProviderId,
55 pub model: String,
56 pub embeddings: Vec<EmbeddingVector>,
57}
58
59#[async_trait::async_trait]
60pub trait EmbeddingProvider: Send + Sync + 'static {
61 fn descriptor(&self) -> EmbeddingProviderDescriptor;
62
63 async fn embed(&self, request: EmbeddingRequest) -> anyhow::Result<EmbeddingResponse>;
64}
65
66#[derive(Clone)]
67pub struct EmbeddingProviderFactory {
68 provider: Arc<dyn EmbeddingProvider>,
69}
70
71impl EmbeddingProviderFactory {
72 pub fn new(provider: Arc<dyn EmbeddingProvider>) -> Self {
73 Self { provider }
74 }
75
76 pub fn id(&self) -> EmbeddingProviderId {
77 self.provider.descriptor().id
78 }
79
80 pub fn create(&self) -> Arc<dyn EmbeddingProvider> {
81 self.provider.clone()
82 }
83}
84
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn embedding_request_serializes_query_intent() {
        let request = EmbeddingRequest {
            model: "zembed-1".to_string(),
            inputs: vec!["what changed?".to_string()],
            input_type: EmbeddingInputType::Query,
            dimensions: Some(2560),
        };

        let json = serde_json::to_value(request).unwrap();

        assert_eq!(json["inputType"], "query");
        assert_eq!(json["dimensions"], 2560);
    }

    #[test]
    fn embedding_request_defaults_to_document_intent() {
        let request: EmbeddingRequest = serde_json::from_value(serde_json::json!({
            "model": "zembed-1",
            "inputs": ["stored memory"]
        }))
        .unwrap();

        assert_eq!(request.input_type, EmbeddingInputType::Document);
        assert_eq!(request.dimensions, None);
    }

    /// `dimensions` carries `skip_serializing_if = "Option::is_none"`, so an unset value must
    /// not appear in the JSON at all (not even as `null`).
    #[test]
    fn embedding_request_omits_unset_dimensions() {
        let request = EmbeddingRequest {
            model: "zembed-1".to_string(),
            inputs: vec!["stored memory".to_string()],
            input_type: EmbeddingInputType::Document,
            dimensions: None,
        };

        let json = serde_json::to_value(request).unwrap();

        assert_eq!(json["inputType"], "document");
        assert!(json.get("dimensions").is_none());
    }

    #[test]
    fn embedding_input_type_uses_lowercase_names() {
        assert_eq!(
            serde_json::to_value(EmbeddingInputType::Document).unwrap(),
            "document"
        );
        let parsed: EmbeddingInputType =
            serde_json::from_value(serde_json::json!("query")).unwrap();
        assert_eq!(parsed, EmbeddingInputType::Query);
    }

    /// The `default` flag is `#[serde(default)]`, so descriptors without it deserialize as
    /// non-default models.
    #[test]
    fn embedding_model_descriptor_default_flag_is_optional() {
        let model: EmbeddingModelDescriptor = serde_json::from_value(serde_json::json!({
            "id": "zembed-1",
            "dimensions": 2560
        }))
        .unwrap();

        assert_eq!(model.id, "zembed-1");
        assert_eq!(model.dimensions, 2560);
        assert!(!model.default);
    }
}