use crate::llm::traits::AiProvider;
use crate::llm::types::{ChatCompletionParams, ProviderResponse};
use crate::llm::utils::normalize_model_name;
use anyhow::Result;
use std::env;
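/// Provider adapter for Cloudflare Workers AI. Models are matched either by
/// open-weight family name or by the `@cf/`/`@hf/` catalog namespaces.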
#[derive(Debug, Clone)]
pub struct CloudflareWorkersAiProvider;
impl Default for CloudflareWorkersAiProvider {
fn default() -> Self {
Self::new()
}
}
impl CloudflareWorkersAiProvider {
pub fn new() -> Self {
Self
}
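    /// Reads the Cloudflare API token from the environment.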
    fn get_api_token(&self) -> Result<String> {
        env::var(CLOUDFLARE_API_TOKEN_ENV).map_err(|_| {
            anyhow::anyhow!("{} not found in environment", CLOUDFLARE_API_TOKEN_ENV)
        })
    }
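    /// Reads the Cloudflare account ID from the environment.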
    fn get_account_id(&self) -> Result<String> {
        env::var(CLOUDFLARE_ACCOUNT_ID_ENV).map_err(|_| {
            anyhow::anyhow!("{} not found in environment", CLOUDFLARE_ACCOUNT_ID_ENV)
        })
    }
}
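// Environment variables required to authenticate against the Workers AI REST API.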
const CLOUDFLARE_API_TOKEN_ENV: &str = "CLOUDFLARE_API_TOKEN";
const CLOUDFLARE_ACCOUNT_ID_ENV: &str = "CLOUDFLARE_ACCOUNT_ID";
#[async_trait::async_trait]
impl AiProvider for CloudflareWorkersAiProvider {
fn name(&self) -> &str {
"cloudflare"
}
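    /// Heuristic match: accepts well-known open-weight model families as well
    /// as any model explicitly namespaced with `@cf/` or `@hf/`.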
fn supports_model(&self, model: &str) -> bool {
let model_lower = normalize_model_name(model);
model_lower.contains("llama")
|| model_lower.contains("mistral")
|| model_lower.contains("qwen")
|| model_lower.contains("phi")
|| model_lower.contains("tinyllama")
|| model_lower.contains("gemma")
|| model_lower.contains("codellama")
|| model_lower.contains("deepseek")
|| model_lower.contains("neural-chat")
|| model_lower.contains("openchat")
|| model_lower.contains("starling")
|| model_lower.contains("zephyr")
|| model.starts_with("@cf/")
|| model.starts_with("@hf/")
}
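    /// Returns the API token, first checking that the account ID is also
    /// configured, since both are required to call Workers AI.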
    fn get_api_key(&self) -> Result<String> {
        let api_token = self.get_api_token()?;
        // Validate that the account ID is present even though only the token is returned.
        let _account_id = self.get_account_id()?;
        Ok(api_token)
    }
fn supports_caching(&self, _model: &str) -> bool {
false
}
fn supports_vision(&self, _model: &str) -> bool {
false
}
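    /// Rough per-family context-window limits; unknown models fall back to a
    /// conservative 4,096 tokens.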
    fn get_max_input_tokens(&self, model: &str) -> usize {
        let model_lower = normalize_model_name(model);
        if model_lower.contains("llama-3.1") || model_lower.contains("llama-3.2") {
            32_768
        } else if model_lower.contains("llama-2") {
            4_096
        } else if model_lower.contains("mistral-7b") {
            8_192
        } else if model_lower.contains("qwen1.5") {
            32_768
        } else if model_lower.contains("codellama") {
            16_384
        } else if model_lower.contains("gemma") {
            8_192
        } else {
            4_096
        }
    }
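    /// Chat completion is not yet implemented for this provider in octolib;
    /// callers receive an explicit error instead of a silent no-op.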
async fn chat_completion(&self, _params: ChatCompletionParams) -> Result<ProviderResponse> {
Err(anyhow::anyhow!(
"Cloudflare Workers AI provider not fully implemented in octolib"
))
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_supports_model() {
let provider = CloudflareWorkersAiProvider::new();
assert!(provider.supports_model("llama-3.1-70b-instruct"));
assert!(provider.supports_model("@cf/meta/llama-3.1-70b-instruct"));
assert!(provider.supports_model("@hf/meta/llama-3.1-8b-instruct"));
assert!(provider.supports_model("mistral-7b-instruct-v0.1"));
assert!(provider.supports_model("gemma-2-27b-it"));
assert!(!provider.supports_model("gpt-4"));
assert!(!provider.supports_model("claude-3"));
}
#[test]
fn test_supports_model_case_insensitive() {
let provider = CloudflareWorkersAiProvider::new();
assert!(provider.supports_model("LLAMA-3.1-70B-INSTRUCT"));
assert!(provider.supports_model("MISTRAL-7B-INSTRUCT-V0.1"));
assert!(provider.supports_model("Llama-3.1-70B-Instruct"));
assert!(provider.supports_model("GEMMA-2-27B-IT"));
}
}