1use anyhow::Result;
2use async_trait::async_trait;
3use serde::{Deserialize, Serialize};
4use serde_json::{json, Value};
5
6use super::api_client::{ApiClient, AuthMethod};
7use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage};
8use super::errors::ProviderError;
9use super::formats::snowflake::{create_request, get_usage, response_to_message};
10use super::retry::ProviderRetry;
11use super::utils::{get_model, map_http_error_to_provider_error, ImageFormat, RequestLog};
12use crate::config::ConfigError;
13use crate::conversation::message::Message;
14
15use crate::model::ModelConfig;
16use rmcp::model::Tool;
17
/// Model selected when the user has not chosen one explicitly.
pub const SNOWFLAKE_DEFAULT_MODEL: &str = "claude-sonnet-4-5";

/// Models offered for the Snowflake provider; the default model is listed first.
pub const SNOWFLAKE_KNOWN_MODELS: &[&str] = &[
    SNOWFLAKE_DEFAULT_MODEL,
    "claude-haiku-4-5",
    "claude-4-sonnet",
    "claude-4-opus",
    "claude-3-7-sonnet",
    "claude-3-5-sonnet",
];

/// Snowflake Cortex AISQL documentation section on choosing a model.
pub const SNOWFLAKE_DOC_URL: &str =
    "https://docs.snowflake.com/user-guide/snowflake-cortex/aisql#choosing-a-model";
33
34#[derive(Debug, Clone, Serialize, Deserialize)]
35pub enum SnowflakeAuth {
36 Token(String),
37}
38
39impl SnowflakeAuth {
40 pub fn token(token: String) -> Self {
41 Self::Token(token)
42 }
43}
44
45#[derive(Debug, serde::Serialize)]
46pub struct SnowflakeProvider {
47 #[serde(skip)]
48 api_client: ApiClient,
49 model: ModelConfig,
50 image_format: ImageFormat,
51 #[serde(skip)]
52 name: String,
53}
54
55impl SnowflakeProvider {
56 pub async fn from_env(model: ModelConfig) -> Result<Self> {
57 let config = crate::config::Config::global();
58 let mut host: Result<String, ConfigError> = config.get_param("SNOWFLAKE_HOST");
59 if host.is_err() {
60 host = config.get_secret("SNOWFLAKE_HOST")
61 }
62 if host.is_err() {
63 return Err(ConfigError::NotFound(
64 "Did not find SNOWFLAKE_HOST in either config file or keyring".to_string(),
65 )
66 .into());
67 }
68
69 let mut host = host?;
70
71 host = host.to_lowercase();
73
74 if !host.ends_with("snowflakecomputing.com") {
76 host = format!("{}.snowflakecomputing.com", host);
77 }
78
79 let mut token: Result<String, ConfigError> = config.get_param("SNOWFLAKE_TOKEN");
80
81 if token.is_err() {
82 token = config.get_secret("SNOWFLAKE_TOKEN")
83 }
84
85 if token.is_err() {
86 return Err(ConfigError::NotFound(
87 "Did not find SNOWFLAKE_TOKEN in either config file or keyring".to_string(),
88 )
89 .into());
90 }
91
92 let base_url = if !host.starts_with("https://") && !host.starts_with("http://") {
94 format!("https://{}", host)
95 } else {
96 host
97 };
98
99 let auth = AuthMethod::BearerToken(token?);
100 let api_client = ApiClient::new(base_url, auth)?.with_header("User-Agent", "aster")?;
101
102 Ok(Self {
103 api_client,
104 model,
105 image_format: ImageFormat::OpenAi,
106 name: Self::metadata().name,
107 })
108 }
109
110 async fn post(&self, payload: &Value) -> Result<Value, ProviderError> {
111 let response = self
112 .api_client
113 .response_post("api/v2/cortex/inference:complete", payload)
114 .await?;
115
116 let status = response.status();
117 let payload_text: String = response.text().await.ok().unwrap_or_default();
118
119 if status.is_success() {
120 if let Ok(payload) = serde_json::from_str::<Value>(&payload_text) {
121 if payload.get("code").is_some() {
122 let code = payload
123 .get("code")
124 .and_then(|c| c.as_str())
125 .unwrap_or("Unknown code");
126 let message = payload
127 .get("message")
128 .and_then(|m| m.as_str())
129 .unwrap_or("Unknown message");
130 return Err(ProviderError::RequestFailed(format!(
131 "{} - {}",
132 code, message
133 )));
134 }
135 }
136 }
137
138 let lines = payload_text.lines().collect::<Vec<_>>();
139
140 let mut text = String::new();
141 let mut tool_name = String::new();
142 let mut tool_input = String::new();
143 let mut tool_use_id = String::new();
144 for line in lines.iter() {
145 if line.is_empty() {
146 continue;
147 }
148
149 let json_str = match line.strip_prefix("data: ") {
150 Some(s) => s,
151 None => continue,
152 };
153
154 if let Ok(json_line) = serde_json::from_str::<Value>(json_str) {
155 let choices = match json_line.get("choices").and_then(|c| c.as_array()) {
156 Some(choices) => choices,
157 None => {
158 continue;
159 }
160 };
161
162 let choice = match choices.first() {
163 Some(choice) => choice,
164 None => {
165 continue;
166 }
167 };
168
169 let delta = match choice.get("delta") {
170 Some(delta) => delta,
171 None => {
172 continue;
173 }
174 };
175
176 let mut found_text_in_content_list = false;
178
179 if let Some(content_list) = delta.get("content_list").and_then(|cl| cl.as_array()) {
181 for content_item in content_list {
182 match content_item.get("type").and_then(|t| t.as_str()) {
183 Some("text") => {
184 if let Some(text_content) =
185 content_item.get("text").and_then(|t| t.as_str())
186 {
187 text.push_str(text_content);
188 found_text_in_content_list = true;
189 }
190 }
191 Some("tool_use") => {
192 if let Some(tool_id) =
193 content_item.get("tool_use_id").and_then(|id| id.as_str())
194 {
195 tool_use_id.push_str(tool_id);
196 }
197 if let Some(name) =
198 content_item.get("name").and_then(|n| n.as_str())
199 {
200 tool_name.push_str(name);
201 }
202 if let Some(input) =
203 content_item.get("input").and_then(|i| i.as_str())
204 {
205 tool_input.push_str(input);
206 }
207 }
208 _ => {
209 if let Some(name) =
211 content_item.get("name").and_then(|n| n.as_str())
212 {
213 tool_name.push_str(name);
214 }
215 if let Some(tool_id) =
216 content_item.get("tool_use_id").and_then(|id| id.as_str())
217 {
218 tool_use_id.push_str(tool_id);
219 }
220 if let Some(input) =
221 content_item.get("input").and_then(|i| i.as_str())
222 {
223 tool_input.push_str(input);
224 }
225 }
226 }
227 }
228 }
229
230 if !found_text_in_content_list {
232 if let Some(content) = delta.get("content").and_then(|c| c.as_str()) {
233 text.push_str(content);
234 }
235 }
236 }
237 }
238
239 let mut content_list = Vec::new();
241
242 if !text.is_empty() {
244 content_list.push(json!({
245 "type": "text",
246 "text": text
247 }));
248 }
249
250 if !tool_use_id.is_empty() && !tool_name.is_empty() {
252 let parsed_input = if tool_input.is_empty() {
254 json!({})
255 } else {
256 serde_json::from_str::<Value>(&tool_input)
257 .unwrap_or_else(|_| json!({"raw_input": tool_input}))
258 };
259
260 content_list.push(json!({
261 "type": "tool_use",
262 "tool_use_id": tool_use_id,
263 "name": tool_name,
264 "input": parsed_input
265 }));
266 }
267
268 if content_list.is_empty() {
270 content_list.push(json!({
271 "type": "text",
272 "text": ""
273 }));
274 }
275
276 let answer_payload = json!({
277 "role": "assistant",
278 "content": text,
279 "content_list": content_list
280 });
281
282 if status.is_success() {
283 Ok(answer_payload)
284 } else {
285 let error_json = serde_json::from_str::<Value>(&payload_text).ok();
286 Err(map_http_error_to_provider_error(status, error_json))
287 }
288 }
289}
290
291#[async_trait]
292impl Provider for SnowflakeProvider {
293 fn metadata() -> ProviderMetadata {
294 ProviderMetadata::new(
295 "snowflake",
296 "Snowflake",
297 "Access the latest models using Snowflake Cortex services.",
298 SNOWFLAKE_DEFAULT_MODEL,
299 SNOWFLAKE_KNOWN_MODELS.to_vec(),
300 SNOWFLAKE_DOC_URL,
301 vec![
302 ConfigKey::new("SNOWFLAKE_HOST", true, false, None),
303 ConfigKey::new("SNOWFLAKE_TOKEN", true, true, None),
304 ],
305 )
306 }
307
308 fn get_name(&self) -> &str {
309 &self.name
310 }
311
312 fn get_model_config(&self) -> ModelConfig {
313 self.model.clone()
314 }
315
316 #[tracing::instrument(
317 skip(self, model_config, system, messages, tools),
318 fields(model_config, input, output, input_tokens, output_tokens, total_tokens)
319 )]
320 async fn complete_with_model(
321 &self,
322 model_config: &ModelConfig,
323 system: &str,
324 messages: &[Message],
325 tools: &[Tool],
326 ) -> Result<(Message, ProviderUsage), ProviderError> {
327 let payload = create_request(model_config, system, messages, tools)?;
328
329 let mut log = RequestLog::start(&self.model, &payload)?;
330
331 let response = self
332 .with_retry(|| async {
333 let payload_clone = payload.clone();
334 self.post(&payload_clone).await
335 })
336 .await?;
337
338 let message = response_to_message(&response)?;
339 let usage = get_usage(&response)?;
340 let response_model = get_model(&response);
341
342 log.write(&response, Some(&usage))?;
343
344 Ok((message, ProviderUsage::new(response_model, usage)))
345 }
346}