Skip to main content

albert_api/
types.rs

1use serde::{Deserialize, Serialize};
2use serde_json::Value;
3
/// API version identifier exported by this module (a date-formatted string).
///
/// NOTE(review): presumably sent as a version header by the HTTP client — confirm at the call site.
pub const TERNLANG_VERSION: &str = "2023-06-01";
5
6#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
7pub struct MessageRequest {
8    pub model: String,
9    pub max_tokens: Option<u32>,
10    pub messages: Vec<InputMessage>,
11    #[serde(skip_serializing_if = "Option::is_none")]
12    pub system: Option<String>,
13    #[serde(skip_serializing_if = "Option::is_none")]
14    pub tools: Option<Vec<ToolDefinition>>,
15    #[serde(skip_serializing_if = "Option::is_none")]
16    pub tool_choice: Option<ToolChoice>,
17    #[serde(default, skip_serializing_if = "std::ops::Not::not")]
18    pub stream: bool,
19}
20
21impl MessageRequest {
22    #[must_use]
23    pub fn with_streaming(mut self) -> Self {
24        self.stream = true;
25        self
26    }
27}
28
29#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
30pub struct InputMessage {
31    pub role: String,
32    pub content: Vec<InputContentBlock>,
33}
34
35impl InputMessage {
36    #[must_use]
37    pub fn user_text(text: impl Into<String>) -> Self {
38        Self {
39            role: "user".to_string(),
40            content: vec![InputContentBlock::Text { text: text.into() }],
41        }
42    }
43
44    #[must_use]
45    pub fn user_tool_result(
46        tool_use_id: impl Into<String>,
47        content: impl Into<String>,
48        is_error: bool,
49    ) -> Self {
50        Self {
51            role: "user".to_string(),
52            content: vec![InputContentBlock::ToolResult {
53                tool_use_id: tool_use_id.into(),
54                content: vec![ToolResultContentBlock::Text {
55                    text: content.into(),
56                }],
57                is_error,
58            }],
59        }
60    }
61}
62
63#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
64#[serde(tag = "type", rename_all = "snake_case")]
65pub enum InputContentBlock {
66    Text {
67        text: String,
68    },
69    ToolUse {
70        id: String,
71        name: String,
72        input: Value,
73    },
74    ToolResult {
75        tool_use_id: String,
76        content: Vec<ToolResultContentBlock>,
77        #[serde(default, skip_serializing_if = "std::ops::Not::not")]
78        is_error: bool,
79    },
80}
81
82#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
83#[serde(tag = "type", rename_all = "snake_case")]
84pub enum ToolResultContentBlock {
85    Text { text: String },
86    Json { value: Value },
87}
88
89#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
90pub struct ToolDefinition {
91    pub name: String,
92    #[serde(skip_serializing_if = "Option::is_none")]
93    pub description: Option<String>,
94    pub input_schema: Value,
95}
96
97#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
98#[serde(tag = "type", rename_all = "snake_case")]
99pub enum ToolChoice {
100    Auto,
101    Any,
102    Tool { name: String },
103}
104
105#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
106pub struct MessageResponse {
107    pub id: String,
108    #[serde(rename = "type")]
109    pub kind: String,
110    pub role: String,
111    pub content: Vec<OutputContentBlock>,
112    pub model: String,
113    #[serde(default)]
114    pub stop_reason: Option<String>,
115    #[serde(default)]
116    pub stop_sequence: Option<String>,
117    pub usage: Usage,
118    #[serde(default)]
119    pub request_id: Option<String>,
120}
121
122impl MessageResponse {
123    #[must_use]
124    pub fn total_tokens(&self) -> u32 {
125        self.usage.total_tokens()
126    }
127}
128
129#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
130#[serde(tag = "type", rename_all = "snake_case")]
131pub enum OutputContentBlock {
132    Text {
133        text: String,
134    },
135    ToolUse {
136        id: String,
137        name: String,
138        input: Value,
139    },
140}
141
142#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
143pub struct Usage {
144    pub input_tokens: u32,
145    #[serde(default)]
146    pub cache_creation_input_tokens: u32,
147    #[serde(default)]
148    pub cache_read_input_tokens: u32,
149    pub output_tokens: u32,
150}
151
152impl Usage {
153    #[must_use]
154    pub const fn total_tokens(&self) -> u32 {
155        self.input_tokens + self.output_tokens
156    }
157}
158
159#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
160pub struct MessageStartEvent {
161    pub message: MessageResponse,
162}
163
164#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
165pub struct MessageDeltaEvent {
166    pub delta: MessageDelta,
167    pub usage: Usage,
168}
169
170#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
171pub struct MessageDelta {
172    #[serde(default)]
173    pub stop_reason: Option<String>,
174    #[serde(default)]
175    pub stop_sequence: Option<String>,
176}
177
178#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
179pub struct ContentBlockStartEvent {
180    pub index: u32,
181    pub content_block: OutputContentBlock,
182}
183
184#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
185pub struct ContentBlockDeltaEvent {
186    pub index: u32,
187    pub delta: ContentBlockDelta,
188}
189
190#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
191#[serde(tag = "type", rename_all = "snake_case")]
192pub enum ContentBlockDelta {
193    TextDelta { text: String },
194    InputJsonDelta { partial_json: String },
195}
196
197#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
198pub struct ContentBlockStopEvent {
199    pub index: u32,
200}
201
202#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
203pub struct MessageStopEvent {}
204
205#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
206#[serde(tag = "type", rename_all = "snake_case")]
207pub enum StreamEvent {
208    MessageStart(MessageStartEvent),
209    MessageDelta(MessageDeltaEvent),
210    ContentBlockStart(ContentBlockStartEvent),
211    ContentBlockDelta(ContentBlockDeltaEvent),
212    ContentBlockStop(ContentBlockStopEvent),
213    MessageStop(MessageStopEvent),
214}
215
216#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
217#[serde(tag = "type", rename_all = "snake_case")]
218pub enum AuthSource {
219    None,
220    ApiKey(String),
221    BearerToken(String),
222    ApiKeyAndBearer { api_key: String, bearer_token: String },
223}
224
225impl AuthSource {
226    pub fn from_env_or_saved() -> Result<Self, crate::error::ApiError> {
227        if let Some(api_key) = crate::client::read_env_non_empty("TERNLANG_API_KEY")? {
228            return match crate::client::read_env_non_empty("TERNLANG_AUTH_TOKEN")? {
229                Some(bearer_token) => Ok(Self::ApiKeyAndBearer {
230                    api_key,
231                    bearer_token,
232                }),
233                None => Ok(Self::ApiKey(api_key)),
234            };
235        }
236        Ok(Self::None)
237    }
238
239    pub fn apply(&self, provider: crate::client::LlmProvider, rb: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
240        use crate::client::LlmProvider;
241        match (provider, self) {
242            (LlmProvider::Google, _) => rb,
243            (LlmProvider::Anthropic, Self::ApiKey(key)) => rb.header("x-api-key", key),
244            (LlmProvider::OpenAi | LlmProvider::Ollama | LlmProvider::Xai, Self::ApiKey(key)) => rb.bearer_auth(key),
245            (_, Self::BearerToken(token)) => rb.bearer_auth(token),
246            (_, Self::ApiKeyAndBearer { api_key, bearer_token }) => {
247                rb.header("x-api-key", api_key).bearer_auth(bearer_token)
248            }
249            _ => rb,
250        }
251    }
252
253    pub fn api_key(&self) -> Option<&str> {
254        match self {
255            Self::ApiKey(key) => Some(key),
256            Self::ApiKeyAndBearer { api_key, .. } => Some(api_key),
257            _ => None,
258        }
259    }
260}
261
262#[derive(Debug, Clone, Serialize, Deserialize)]
263pub struct OAuthTokenSet {
264    pub access_token: String,
265    pub refresh_token: Option<String>,
266    pub expires_at: Option<u64>,
267}
268
269#[derive(Debug, Clone, Serialize, Deserialize)]
270pub struct RuntimeTokenSet {
271    pub access_token: String,
272    pub refresh_token: Option<String>,
273    pub expires_at: Option<u64>,
274    pub scopes: Vec<String>,
275}
276
277#[derive(Debug, Clone, Serialize, Deserialize)]
278pub struct OAuthTokenExchangeRequest {
279    pub code: String,
280    pub redirect_uri: String,
281}