Skip to main content

anyllm_cloudflare_worker/
chat.rs

1//! `ChatProvider` implementation for Cloudflare Workers AI via the `worker::Ai` binding.
2
3#[cfg(feature = "extract")]
4use anyllm::ExtractExt;
5use anyllm::{
6    CapabilitySupport, ChatCapability, ChatProvider, ChatRequest, ChatResponse, ChatStream,
7    ProviderIdentity, Result,
8};
9
10use crate::Provider;
11use crate::error::map_worker_error;
12use crate::streaming::byte_stream_to_chat_stream;
13use crate::wire;
14
/// Identity this provider publishes as `gen_ai.provider.name`.
///
/// Deliberately the same string the Cloudflare preset in `anyllm-openai-compat`
/// reports. Dashboards keyed on `gen_ai.provider.name` therefore see a single
/// value, whichever transport the user chose: this `worker::Ai` binding or
/// that preset.
pub(crate) const PROVIDER_NAME: &str = "cloudflare";
21
22impl ProviderIdentity for Provider {
23    fn provider_name(&self) -> &'static str {
24        PROVIDER_NAME
25    }
26}
27
28impl ChatProvider for Provider {
29    type Stream = ChatStream;
30
31    async fn chat(&self, request: &ChatRequest) -> Result<ChatResponse> {
32        let cf_request = wire::ChatRequest::try_from(request)?;
33
34        let response: wire::ChatResponse = self
35            .ai
36            .run(&request.model, &cf_request)
37            .await
38            .map_err(map_worker_error)?;
39
40        response.try_into()
41    }
42
43    async fn chat_stream(&self, request: &ChatRequest) -> Result<ChatStream> {
44        wire::reject_unsupported_streaming_request_features(request)?;
45
46        let mut cf_request = wire::ChatRequest::try_from(request)?;
47        cf_request.stream = Some(true);
48
49        let byte_stream = self
50            .ai
51            .run_bytes(&request.model, &cf_request)
52            .await
53            .map_err(map_worker_error)?;
54
55        Ok(byte_stream_to_chat_stream(byte_stream))
56    }
57
58    fn chat_capability(&self, model: &str, capability: ChatCapability) -> CapabilitySupport {
59        if let Some(support) = self
60            .chat_capability_resolver
61            .as_ref()
62            .and_then(|resolver| resolver.chat_capability(model, capability))
63        {
64            support
65        } else {
66            self.builtin_chat_capability(model, capability)
67        }
68    }
69}
70
/// Implements `ExtractExt` for this provider using only the trait's provided
/// items (the impl body is empty). Compiled only with the `extract` feature.
#[cfg(feature = "extract")]
impl ExtractExt for Provider {}
73
#[cfg(test)]
mod tests {
    use super::PROVIDER_NAME;

    /// `gen_ai.provider.name` is part of the public contract. Dashboards and
    /// OTEL backends keyed on it break if the value shifts, so it is pinned
    /// here to stop a refactor from quietly renaming it.
    #[test]
    fn provider_name_is_pinned_to_cloudflare() {
        let expected = "cloudflare";
        assert_eq!(PROVIDER_NAME, expected);
    }
}