anyllm_cloudflare_worker/
embedding.rs1use anyllm::{
9 CapabilitySupport, EmbeddingCapability, EmbeddingProvider, EmbeddingRequest, EmbeddingResponse,
10 Error, Result,
11};
12use serde::{Deserialize, Serialize};
13
14use crate::Provider;
15use crate::error::map_worker_error;
16
17#[derive(Debug, Serialize)]
19pub(crate) struct EmbedRequest {
20 pub text: Vec<String>,
21}
22
23#[derive(Debug, Deserialize)]
28pub(crate) struct EmbedResponse {
29 pub data: Vec<Vec<f32>>,
30}
31
32impl TryFrom<&EmbeddingRequest> for EmbedRequest {
33 type Error = Error;
34
35 fn try_from(request: &EmbeddingRequest) -> Result<Self> {
36 if request.dimensions.is_some() {
37 return Err(Error::Unsupported(
38 "cloudflare-worker embedding does not support output dimension selection".into(),
39 ));
40 }
41 if request.inputs.is_empty() {
42 return Err(Error::InvalidRequest(
43 "embedding request has no inputs".into(),
44 ));
45 }
46 Ok(Self {
47 text: request.inputs.clone(),
48 })
49 }
50}
51
52impl From<EmbedResponse> for EmbeddingResponse {
53 fn from(response: EmbedResponse) -> Self {
54 EmbeddingResponse::new(response.data)
55 }
56}
57
58impl EmbeddingProvider for Provider {
59 async fn embed(&self, request: &EmbeddingRequest) -> Result<EmbeddingResponse> {
60 let cf_request = EmbedRequest::try_from(request)?;
61
62 let response: EmbedResponse = self
63 .ai
64 .run(&request.model, &cf_request)
65 .await
66 .map_err(map_worker_error)?;
67
68 Ok(response.into())
69 }
70
71 fn embedding_capability(
72 &self,
73 _model: &str,
74 capability: EmbeddingCapability,
75 ) -> CapabilitySupport {
76 match capability {
77 EmbeddingCapability::BatchInput => CapabilitySupport::Supported,
78 EmbeddingCapability::OutputDimensions => CapabilitySupport::Unsupported,
79 _ => CapabilitySupport::Unknown,
80 }
81 }
82}
83
#[cfg(test)]
mod tests {
    use super::*;

    /// Workers AI model id shared by the request-conversion tests.
    const MODEL: &str = "@cf/baai/bge-base-en-v1.5";

    #[test]
    fn request_conversion_forwards_inputs() {
        let request = EmbeddingRequest::new(MODEL).input("hello").input("world");

        let converted = EmbedRequest::try_from(&request).unwrap();
        let expected: Vec<String> = ["hello", "world"].iter().map(|s| s.to_string()).collect();
        assert_eq!(converted.text, expected);
    }

    #[test]
    fn request_rejects_dimension_override() {
        let request = EmbeddingRequest::new(MODEL).input("hi").dimensions(256);

        let outcome = EmbedRequest::try_from(&request);
        assert!(matches!(outcome, Err(Error::Unsupported(_))));
    }

    #[test]
    fn request_rejects_empty_inputs() {
        let outcome = EmbedRequest::try_from(&EmbeddingRequest::new(MODEL));
        assert!(matches!(outcome, Err(Error::InvalidRequest(_))));
    }

    #[test]
    fn response_conversion_preserves_vectors() {
        let response = EmbeddingResponse::from(EmbedResponse {
            data: vec![vec![0.1, 0.2, 0.3], vec![0.4, 0.5, 0.6]],
        });

        let rows = &response.embeddings;
        assert_eq!(rows.len(), 2);
        assert_eq!(rows[0], vec![0.1, 0.2, 0.3]);
        assert_eq!(rows[1], vec![0.4, 0.5, 0.6]);
    }
}