// anyllm_cloudflare_worker/embedding.rs

//! `EmbeddingProvider` implementation for Cloudflare Workers AI.
//!
//! Wraps the `worker::Ai::run()` binding for text-embedding models like
//! `@cf/baai/bge-base-en-v1.5`. Wire types mirror the native Workers AI
//! embedding schema: `{ "text": [...] }` in, `{ "shape": [N, D], "data": [...] }`
//! out.
7
8use anyllm::{
9    CapabilitySupport, EmbeddingCapability, EmbeddingProvider, EmbeddingRequest, EmbeddingResponse,
10    Error, Result,
11};
12use serde::{Deserialize, Serialize};
13
14use crate::Provider;
15use crate::error::map_worker_error;
16
17/// Workers AI embedding request body.
18#[derive(Debug, Serialize)]
19pub(crate) struct EmbedRequest {
20    pub text: Vec<String>,
21}
22
23/// Workers AI embedding response body.
24///
25/// The native API returns `{ "shape": [batch, dimensions], "data": [...] }`.
26/// `shape` is ignored — the portable response inspects `data` directly.
27#[derive(Debug, Deserialize)]
28pub(crate) struct EmbedResponse {
29    pub data: Vec<Vec<f32>>,
30}
31
32impl TryFrom<&EmbeddingRequest> for EmbedRequest {
33    type Error = Error;
34
35    fn try_from(request: &EmbeddingRequest) -> Result<Self> {
36        if request.dimensions.is_some() {
37            return Err(Error::Unsupported(
38                "cloudflare-worker embedding does not support output dimension selection".into(),
39            ));
40        }
41        if request.inputs.is_empty() {
42            return Err(Error::InvalidRequest(
43                "embedding request has no inputs".into(),
44            ));
45        }
46        Ok(Self {
47            text: request.inputs.clone(),
48        })
49    }
50}
51
52impl From<EmbedResponse> for EmbeddingResponse {
53    fn from(response: EmbedResponse) -> Self {
54        EmbeddingResponse::new(response.data)
55    }
56}
57
58impl EmbeddingProvider for Provider {
59    async fn embed(&self, request: &EmbeddingRequest) -> Result<EmbeddingResponse> {
60        let cf_request = EmbedRequest::try_from(request)?;
61
62        let response: EmbedResponse = self
63            .ai
64            .run(&request.model, &cf_request)
65            .await
66            .map_err(map_worker_error)?;
67
68        Ok(response.into())
69    }
70
71    fn embedding_capability(
72        &self,
73        _model: &str,
74        capability: EmbeddingCapability,
75    ) -> CapabilitySupport {
76        match capability {
77            EmbeddingCapability::BatchInput => CapabilitySupport::Supported,
78            EmbeddingCapability::OutputDimensions => CapabilitySupport::Unsupported,
79            _ => CapabilitySupport::Unknown,
80        }
81    }
82}
83
#[cfg(test)]
mod tests {
    use super::*;

    /// Model id shared by every case; the conversions under test never read it.
    const MODEL: &str = "@cf/baai/bge-base-en-v1.5";

    #[test]
    fn request_conversion_forwards_inputs() {
        let portable = EmbeddingRequest::new(MODEL).input("hello").input("world");

        let wire = EmbedRequest::try_from(&portable).unwrap();
        assert_eq!(wire.text, ["hello", "world"]);
    }

    #[test]
    fn request_rejects_dimension_override() {
        let portable = EmbeddingRequest::new(MODEL).input("hi").dimensions(256);
        let outcome = EmbedRequest::try_from(&portable);
        assert!(matches!(outcome, Err(Error::Unsupported(_))));
    }

    #[test]
    fn request_rejects_empty_inputs() {
        let portable = EmbeddingRequest::new(MODEL);
        let outcome = EmbedRequest::try_from(&portable);
        assert!(matches!(outcome, Err(Error::InvalidRequest(_))));
    }

    #[test]
    fn response_conversion_preserves_vectors() {
        let rows = vec![vec![0.1, 0.2, 0.3], vec![0.4, 0.5, 0.6]];
        let response = EmbeddingResponse::from(EmbedResponse { data: rows.clone() });

        assert_eq!(response.embeddings.len(), rows.len());
        for (idx, expected) in rows.iter().enumerate() {
            assert_eq!(response.embeddings[idx], *expected);
        }
    }
}
120        assert_eq!(response.embeddings.len(), 2);
121        assert_eq!(response.embeddings[0], vec![0.1, 0.2, 0.3]);
122        assert_eq!(response.embeddings[1], vec![0.4, 0.5, 0.6]);
123    }
124}