rig_cat/provider/anthropic/
mod.rs1use comp_cat_rs::effect::io::Io;
4use comp_cat_rs::effect::stream::Stream;
5use serde::{Deserialize, Serialize};
6
7use crate::error::Error;
8use crate::model::{
9 CompletionModel, CompletionRequest, CompletionResponse, StreamChunk,
10};
11
/// An Anthropic API key, sent as the `x-api-key` request header.
///
/// The raw key is private to this module. `Debug` is implemented by hand and
/// redacts the value so the secret cannot leak through `{:?}` formatting
/// (logs, panic messages, `dbg!`).
#[derive(Clone)]
pub struct ApiKey(String);

impl ApiKey {
    /// Wraps a raw API key string.
    #[must_use]
    pub fn new(key: String) -> Self {
        Self(key)
    }

    /// Borrows the raw key, for use when building request headers.
    fn as_str(&self) -> &str {
        &self.0
    }
}

impl std::fmt::Debug for ApiKey {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Never print the secret itself.
        f.write_str("ApiKey(<redacted>)")
    }
}
22
/// Identifier of the Anthropic model to request, sent as the `model` field
/// of the Messages API request body.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ModelName(String);

impl ModelName {
    /// Wraps a model identifier string.
    #[must_use]
    pub fn new(name: String) -> Self {
        Self(name)
    }

    /// Borrows the model identifier.
    fn as_str(&self) -> &str {
        &self.0
    }
}
33
34pub struct AnthropicCompletion {
36 api_key: ApiKey,
37 model: ModelName,
38 max_tokens: u32,
39}
40
41impl AnthropicCompletion {
42 #[must_use]
43 pub fn new(api_key: ApiKey, model: ModelName, max_tokens: u32) -> Self {
44 Self { api_key, model, max_tokens }
45 }
46}
47
48#[derive(Serialize)]
51struct MessagesRequest {
52 model: String,
53 max_tokens: u32,
54 #[serde(skip_serializing_if = "Option::is_none")]
55 system: Option<String>,
56 messages: Vec<AnthropicMessage>,
57 #[serde(skip_serializing_if = "Option::is_none")]
58 temperature: Option<f64>,
59}
60
61#[derive(Serialize)]
62struct AnthropicMessage {
63 role: String,
64 content: String,
65}
66
67#[derive(Deserialize)]
68struct MessagesResponse {
69 content: Vec<ContentBlock>,
70 model: String,
71}
72
73#[derive(Deserialize)]
74struct ContentBlock {
75 text: Option<String>,
76}
77
78impl CompletionModel for AnthropicCompletion {
81 fn complete(&self, request: CompletionRequest) -> Io<Error, CompletionResponse> {
82 let api_key = self.api_key.clone();
83 let model_name = self.model.clone();
84 let default_max = self.max_tokens;
85 Io::suspend(move || {
86 let system_msg = request.messages().iter()
87 .find(|m| matches!(m.role(), crate::model::Role::System))
88 .map(|m| m.content().to_owned());
89
90 let messages: Vec<AnthropicMessage> = request.messages().iter()
91 .filter(|m| !matches!(m.role(), crate::model::Role::System))
92 .map(|m| AnthropicMessage {
93 role: match m.role() {
94 crate::model::Role::Assistant => "assistant".to_owned(),
95 crate::model::Role::User | crate::model::Role::System => "user".to_owned(),
96 },
97 content: m.content().to_owned(),
98 })
99 .collect();
100
101 let body = MessagesRequest {
102 model: model_name.as_str().to_owned(),
103 max_tokens: request.max_tokens().unwrap_or(default_max),
104 system: system_msg,
105 messages,
106 temperature: request.temperature(),
107 };
108
109 let resp: MessagesResponse = ureq::post("https://api.anthropic.com/v1/messages")
110 .header("x-api-key", api_key.as_str())
111 .header("anthropic-version", "2023-06-01")
112 .header("Content-Type", "application/json")
113 .send_json(&body)
114 .map_err(Error::from)?
115 .into_body()
116 .read_json()
117 .map_err(Error::from)?;
118
119 let content: String = resp.content.iter()
120 .filter_map(|b| b.text.clone())
121 .collect();
122
123 Ok(CompletionResponse::new(content, resp.model))
124 })
125 }
126
127 fn stream(&self, _request: CompletionRequest) -> Stream<Error, StreamChunk> {
128 Stream::empty()
130 }
131}