Skip to main content

rig_cat/provider/anthropic/
mod.rs

//! Anthropic provider: completion model.
2
3use comp_cat_rs::effect::io::Io;
4use comp_cat_rs::effect::stream::Stream;
5use serde::{Deserialize, Serialize};
6
7use crate::error::Error;
8use crate::model::{
9    CompletionModel, CompletionRequest, CompletionResponse, StreamChunk,
10};
11
/// Newtype for the Anthropic API key.
///
/// The raw key is only readable inside this module (via `as_str`, used for the
/// `x-api-key` request header). The `Debug` implementation redacts the value so the
/// secret cannot leak through `{:?}` formatting or debug logs.
#[derive(Clone)]
pub struct ApiKey(String);

impl ApiKey {
    /// Wraps a raw API key string.
    #[must_use]
    pub fn new(key: String) -> Self {
        Self(key)
    }

    /// Borrows the raw key for use in request headers.
    fn as_str(&self) -> &str {
        &self.0
    }
}

impl std::fmt::Debug for ApiKey {
    /// Prints a fixed placeholder instead of the key itself.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("ApiKey(<redacted>)")
    }
}
22
/// Newtype for a model name (sent verbatim as the request's `model` field).
///
/// Unlike [`ApiKey`], the value is not secret, so it derives `Debug` and the usual
/// comparison/hash traits for use in logs, maps, and tests.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct ModelName(String);

impl ModelName {
    /// Wraps a model identifier string.
    #[must_use]
    pub fn new(name: String) -> Self {
        Self(name)
    }

    /// Borrows the model identifier.
    fn as_str(&self) -> &str {
        &self.0
    }
}
33
34/// Anthropic completion model.
35pub struct AnthropicCompletion {
36    api_key: ApiKey,
37    model: ModelName,
38    max_tokens: u32,
39}
40
41impl AnthropicCompletion {
42    #[must_use]
43    pub fn new(api_key: ApiKey, model: ModelName, max_tokens: u32) -> Self {
44        Self { api_key, model, max_tokens }
45    }
46}
47
// --- Request/response JSON shapes ---
49
50#[derive(Serialize)]
51struct MessagesRequest {
52    model: String,
53    max_tokens: u32,
54    #[serde(skip_serializing_if = "Option::is_none")]
55    system: Option<String>,
56    messages: Vec<AnthropicMessage>,
57    #[serde(skip_serializing_if = "Option::is_none")]
58    temperature: Option<f64>,
59}
60
61#[derive(Serialize)]
62struct AnthropicMessage {
63    role: String,
64    content: String,
65}
66
67#[derive(Deserialize)]
68struct MessagesResponse {
69    content: Vec<ContentBlock>,
70    model: String,
71}
72
73#[derive(Deserialize)]
74struct ContentBlock {
75    text: Option<String>,
76}
77
// --- Trait impl ---
79
80impl CompletionModel for AnthropicCompletion {
81    fn complete(&self, request: CompletionRequest) -> Io<Error, CompletionResponse> {
82        let api_key = self.api_key.clone();
83        let model_name = self.model.clone();
84        let default_max = self.max_tokens;
85        Io::suspend(move || {
86            let system_msg = request.messages().iter()
87                .find(|m| matches!(m.role(), crate::model::Role::System))
88                .map(|m| m.content().to_owned());
89
90            let messages: Vec<AnthropicMessage> = request.messages().iter()
91                .filter(|m| !matches!(m.role(), crate::model::Role::System))
92                .map(|m| AnthropicMessage {
93                    role: match m.role() {
94                        crate::model::Role::Assistant => "assistant".to_owned(),
95                        crate::model::Role::User | crate::model::Role::System => "user".to_owned(),
96                    },
97                    content: m.content().to_owned(),
98                })
99                .collect();
100
101            let body = MessagesRequest {
102                model: model_name.as_str().to_owned(),
103                max_tokens: request.max_tokens().unwrap_or(default_max),
104                system: system_msg,
105                messages,
106                temperature: request.temperature(),
107            };
108
109            let resp: MessagesResponse = ureq::post("https://api.anthropic.com/v1/messages")
110                .header("x-api-key", api_key.as_str())
111                .header("anthropic-version", "2023-06-01")
112                .header("Content-Type", "application/json")
113                .send_json(&body)
114                .map_err(Error::from)?
115                .into_body()
116                .read_json()
117                .map_err(Error::from)?;
118
119            let content: String = resp.content.iter()
120                .filter_map(|b| b.text.clone())
121                .collect();
122
123            Ok(CompletionResponse::new(content, resp.model))
124        })
125    }
126
127    fn stream(&self, _request: CompletionRequest) -> Stream<Error, StreamChunk> {
128        // TODO: implement SSE streaming
129        Stream::empty()
130    }
131}