// aster/providers/snowflake.rs

1use anyhow::Result;
2use async_trait::async_trait;
3use serde::{Deserialize, Serialize};
4use serde_json::{json, Value};
5
6use super::api_client::{ApiClient, AuthMethod};
7use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage};
8use super::errors::ProviderError;
9use super::formats::snowflake::{create_request, get_usage, response_to_message};
10use super::retry::ProviderRetry;
11use super::utils::{get_model, map_http_error_to_provider_error, ImageFormat, RequestLog};
12use crate::config::ConfigError;
13use crate::conversation::message::Message;
14
15use crate::model::ModelConfig;
16use rmcp::model::Tool;
17
/// Default model identifier reported in the Snowflake provider metadata.
pub const SNOWFLAKE_DEFAULT_MODEL: &str = "claude-sonnet-4-5";

/// Model identifiers listed in the Snowflake provider metadata, newest family first.
///
/// The first entry is always [`SNOWFLAKE_DEFAULT_MODEL`].
pub const SNOWFLAKE_KNOWN_MODELS: &[&str] = &[
    // Claude 4.5 family
    SNOWFLAKE_DEFAULT_MODEL,
    "claude-haiku-4-5",
    // Claude 4 family
    "claude-4-sonnet",
    "claude-4-opus",
    // Claude 3 family
    "claude-3-7-sonnet",
    "claude-3-5-sonnet",
];

/// Snowflake documentation page on choosing a Cortex model.
pub const SNOWFLAKE_DOC_URL: &str =
    "https://docs.snowflake.com/user-guide/snowflake-cortex/aisql#choosing-a-model";

34#[derive(Debug, Clone, Serialize, Deserialize)]
35pub enum SnowflakeAuth {
36    Token(String),
37}
38
39impl SnowflakeAuth {
40    pub fn token(token: String) -> Self {
41        Self::Token(token)
42    }
43}
44
45#[derive(Debug, serde::Serialize)]
46pub struct SnowflakeProvider {
47    #[serde(skip)]
48    api_client: ApiClient,
49    model: ModelConfig,
50    image_format: ImageFormat,
51    #[serde(skip)]
52    name: String,
53}
54
55impl SnowflakeProvider {
56    pub async fn from_env(model: ModelConfig) -> Result<Self> {
57        let config = crate::config::Config::global();
58        let mut host: Result<String, ConfigError> = config.get_param("SNOWFLAKE_HOST");
59        if host.is_err() {
60            host = config.get_secret("SNOWFLAKE_HOST")
61        }
62        if host.is_err() {
63            return Err(ConfigError::NotFound(
64                "Did not find SNOWFLAKE_HOST in either config file or keyring".to_string(),
65            )
66            .into());
67        }
68
69        let mut host = host?;
70
71        // Convert host to lowercase
72        host = host.to_lowercase();
73
74        // Ensure host ends with snowflakecomputing.com
75        if !host.ends_with("snowflakecomputing.com") {
76            host = format!("{}.snowflakecomputing.com", host);
77        }
78
79        let mut token: Result<String, ConfigError> = config.get_param("SNOWFLAKE_TOKEN");
80
81        if token.is_err() {
82            token = config.get_secret("SNOWFLAKE_TOKEN")
83        }
84
85        if token.is_err() {
86            return Err(ConfigError::NotFound(
87                "Did not find SNOWFLAKE_TOKEN in either config file or keyring".to_string(),
88            )
89            .into());
90        }
91
92        // Ensure host has https:// prefix
93        let base_url = if !host.starts_with("https://") && !host.starts_with("http://") {
94            format!("https://{}", host)
95        } else {
96            host
97        };
98
99        let auth = AuthMethod::BearerToken(token?);
100        let api_client = ApiClient::new(base_url, auth)?.with_header("User-Agent", "aster")?;
101
102        Ok(Self {
103            api_client,
104            model,
105            image_format: ImageFormat::OpenAi,
106            name: Self::metadata().name,
107        })
108    }
109
110    async fn post(&self, payload: &Value) -> Result<Value, ProviderError> {
111        let response = self
112            .api_client
113            .response_post("api/v2/cortex/inference:complete", payload)
114            .await?;
115
116        let status = response.status();
117        let payload_text: String = response.text().await.ok().unwrap_or_default();
118
119        if status.is_success() {
120            if let Ok(payload) = serde_json::from_str::<Value>(&payload_text) {
121                if payload.get("code").is_some() {
122                    let code = payload
123                        .get("code")
124                        .and_then(|c| c.as_str())
125                        .unwrap_or("Unknown code");
126                    let message = payload
127                        .get("message")
128                        .and_then(|m| m.as_str())
129                        .unwrap_or("Unknown message");
130                    return Err(ProviderError::RequestFailed(format!(
131                        "{} - {}",
132                        code, message
133                    )));
134                }
135            }
136        }
137
138        let lines = payload_text.lines().collect::<Vec<_>>();
139
140        let mut text = String::new();
141        let mut tool_name = String::new();
142        let mut tool_input = String::new();
143        let mut tool_use_id = String::new();
144        for line in lines.iter() {
145            if line.is_empty() {
146                continue;
147            }
148
149            let json_str = match line.strip_prefix("data: ") {
150                Some(s) => s,
151                None => continue,
152            };
153
154            if let Ok(json_line) = serde_json::from_str::<Value>(json_str) {
155                let choices = match json_line.get("choices").and_then(|c| c.as_array()) {
156                    Some(choices) => choices,
157                    None => {
158                        continue;
159                    }
160                };
161
162                let choice = match choices.first() {
163                    Some(choice) => choice,
164                    None => {
165                        continue;
166                    }
167                };
168
169                let delta = match choice.get("delta") {
170                    Some(delta) => delta,
171                    None => {
172                        continue;
173                    }
174                };
175
176                // Track if we found text in content_list to avoid duplication
177                let mut found_text_in_content_list = false;
178
179                // Handle content_list array first
180                if let Some(content_list) = delta.get("content_list").and_then(|cl| cl.as_array()) {
181                    for content_item in content_list {
182                        match content_item.get("type").and_then(|t| t.as_str()) {
183                            Some("text") => {
184                                if let Some(text_content) =
185                                    content_item.get("text").and_then(|t| t.as_str())
186                                {
187                                    text.push_str(text_content);
188                                    found_text_in_content_list = true;
189                                }
190                            }
191                            Some("tool_use") => {
192                                if let Some(tool_id) =
193                                    content_item.get("tool_use_id").and_then(|id| id.as_str())
194                                {
195                                    tool_use_id.push_str(tool_id);
196                                }
197                                if let Some(name) =
198                                    content_item.get("name").and_then(|n| n.as_str())
199                                {
200                                    tool_name.push_str(name);
201                                }
202                                if let Some(input) =
203                                    content_item.get("input").and_then(|i| i.as_str())
204                                {
205                                    tool_input.push_str(input);
206                                }
207                            }
208                            _ => {
209                                // Handle content items without explicit type but with tool information
210                                if let Some(name) =
211                                    content_item.get("name").and_then(|n| n.as_str())
212                                {
213                                    tool_name.push_str(name);
214                                }
215                                if let Some(tool_id) =
216                                    content_item.get("tool_use_id").and_then(|id| id.as_str())
217                                {
218                                    tool_use_id.push_str(tool_id);
219                                }
220                                if let Some(input) =
221                                    content_item.get("input").and_then(|i| i.as_str())
222                                {
223                                    tool_input.push_str(input);
224                                }
225                            }
226                        }
227                    }
228                }
229
230                // Handle direct content field (for text) only if we didn't find text in content_list
231                if !found_text_in_content_list {
232                    if let Some(content) = delta.get("content").and_then(|c| c.as_str()) {
233                        text.push_str(content);
234                    }
235                }
236            }
237        }
238
239        // Build the appropriate response structure
240        let mut content_list = Vec::new();
241
242        // Add text content if available
243        if !text.is_empty() {
244            content_list.push(json!({
245                "type": "text",
246                "text": text
247            }));
248        }
249
250        // Add tool use content only if we have complete tool information
251        if !tool_use_id.is_empty() && !tool_name.is_empty() {
252            // Parse tool input as JSON if it's not empty
253            let parsed_input = if tool_input.is_empty() {
254                json!({})
255            } else {
256                serde_json::from_str::<Value>(&tool_input)
257                    .unwrap_or_else(|_| json!({"raw_input": tool_input}))
258            };
259
260            content_list.push(json!({
261                "type": "tool_use",
262                "tool_use_id": tool_use_id,
263                "name": tool_name,
264                "input": parsed_input
265            }));
266        }
267
268        // Ensure we always have at least some content
269        if content_list.is_empty() {
270            content_list.push(json!({
271                "type": "text",
272                "text": ""
273            }));
274        }
275
276        let answer_payload = json!({
277            "role": "assistant",
278            "content": text,
279            "content_list": content_list
280        });
281
282        if status.is_success() {
283            Ok(answer_payload)
284        } else {
285            let error_json = serde_json::from_str::<Value>(&payload_text).ok();
286            Err(map_http_error_to_provider_error(status, error_json))
287        }
288    }
289}
290
291#[async_trait]
292impl Provider for SnowflakeProvider {
293    fn metadata() -> ProviderMetadata {
294        ProviderMetadata::new(
295            "snowflake",
296            "Snowflake",
297            "Access the latest models using Snowflake Cortex services.",
298            SNOWFLAKE_DEFAULT_MODEL,
299            SNOWFLAKE_KNOWN_MODELS.to_vec(),
300            SNOWFLAKE_DOC_URL,
301            vec![
302                ConfigKey::new("SNOWFLAKE_HOST", true, false, None),
303                ConfigKey::new("SNOWFLAKE_TOKEN", true, true, None),
304            ],
305        )
306    }
307
308    fn get_name(&self) -> &str {
309        &self.name
310    }
311
312    fn get_model_config(&self) -> ModelConfig {
313        self.model.clone()
314    }
315
316    #[tracing::instrument(
317        skip(self, model_config, system, messages, tools),
318        fields(model_config, input, output, input_tokens, output_tokens, total_tokens)
319    )]
320    async fn complete_with_model(
321        &self,
322        model_config: &ModelConfig,
323        system: &str,
324        messages: &[Message],
325        tools: &[Tool],
326    ) -> Result<(Message, ProviderUsage), ProviderError> {
327        let payload = create_request(model_config, system, messages, tools)?;
328
329        let mut log = RequestLog::start(&self.model, &payload)?;
330
331        let response = self
332            .with_retry(|| async {
333                let payload_clone = payload.clone();
334                self.post(&payload_clone).await
335            })
336            .await?;
337
338        let message = response_to_message(&response)?;
339        let usage = get_usage(&response)?;
340        let response_model = get_model(&response);
341
342        log.write(&response, Some(&usage))?;
343
344        Ok((message, ProviderUsage::new(response_model, usage)))
345    }
346}