use futures::Stream;
use std::collections::HashMap;
use std::pin::Pin;
use std::sync::Arc;
use tracing::debug;
use super::config::SnowflakeConfig;
use super::error::SnowflakeError;
use super::model_info::get_available_models;
use crate::core::providers::base::GlobalPoolManager;
use crate::core::providers::unified_provider::ProviderError;
use crate::core::traits::error_mapper::trait_def::ErrorMapper;
use crate::core::traits::provider::ProviderConfig as _;
use crate::core::traits::provider::llm_provider::trait_definition::LLMProvider;
use crate::core::types::health::HealthStatus;
use crate::core::types::responses::{ChatChunk, ChatResponse, EmbeddingResponse};
use crate::core::types::{chat::ChatRequest, embedding::EmbeddingRequest};
use crate::core::types::{context::RequestContext, model::ModelInfo, model::ProviderCapability};
/// Provider-level capabilities returned by `LLMProvider::capabilities`.
// NOTE(review): `ChatCompletionStream` is advertised here, but
// `chat_completion_stream` below returns a not-supported error — confirm
// whether streaming should be listed before it is implemented.
const SNOWFLAKE_CAPABILITIES: &[ProviderCapability] = &[
ProviderCapability::ChatCompletion,
ProviderCapability::ChatCompletionStream,
ProviderCapability::ToolCalling,
];
/// LLM provider backed by the Snowflake Cortex inference REST API.
#[derive(Debug, Clone)]
pub struct SnowflakeProvider {
// Validated Snowflake connection settings (api key, account id, api base).
config: SnowflakeConfig,
// Held for parity with other providers; the leading underscore marks it
// unused — `chat_completion` currently builds its own reqwest client.
_pool_manager: Arc<GlobalPoolManager>,
// Model catalogue built once in `new` from `get_available_models()`.
models: Vec<ModelInfo>,
}
impl SnowflakeProvider {
pub async fn new(config: SnowflakeConfig) -> Result<Self, SnowflakeError> {
config
.validate()
.map_err(|e| SnowflakeError::configuration("snowflake", e))?;
let pool_manager = Arc::new(GlobalPoolManager::new().map_err(|e| {
SnowflakeError::configuration(
"snowflake",
format!("Failed to create pool manager: {}", e),
)
})?);
let models = get_available_models()
.iter()
.map(|info| {
let mut capabilities = vec![
ProviderCapability::ChatCompletion,
ProviderCapability::ChatCompletionStream,
];
if info.supports_tools {
capabilities.push(ProviderCapability::ToolCalling);
}
ModelInfo {
id: info.model_id.to_string(),
name: info.display_name.to_string(),
provider: "snowflake".to_string(),
max_context_length: info.max_context_length as u32,
max_output_length: Some(info.max_output_length as u32),
supports_streaming: true,
supports_tools: info.supports_tools,
supports_multimodal: false,
input_cost_per_1k_tokens: None, output_cost_per_1k_tokens: None,
currency: "USD".to_string(),
capabilities,
created_at: None,
updated_at: None,
metadata: HashMap::new(),
}
})
.collect();
Ok(Self {
config,
_pool_manager: pool_manager,
models,
})
}
pub async fn with_api_key(
api_key: impl Into<String>,
account_id: impl Into<String>,
) -> Result<Self, SnowflakeError> {
let config = SnowflakeConfig {
api_key: Some(api_key.into()),
account_id: Some(account_id.into()),
..Default::default()
};
Self::new(config).await
}
fn get_api_base(&self) -> String {
if let Some(base) = &self.config.api_base {
base.clone()
} else if let Some(account_id) = &self.config.account_id {
format!("https://{}.snowflakecomputing.com/api/v2", account_id)
} else {
std::env::var("SNOWFLAKE_ACCOUNT_ID")
.map(|id| format!("https://{}.snowflakecomputing.com/api/v2", id))
.unwrap_or_else(|_| "https://snowflakecomputing.com/api/v2".to_string())
}
}
fn get_api_key(&self) -> Option<String> {
self.config
.api_key
.clone()
.or_else(|| std::env::var("SNOWFLAKE_JWT").ok())
}
}
impl LLMProvider for SnowflakeProvider {
    fn name(&self) -> &'static str {
        "snowflake"
    }

    fn capabilities(&self) -> &'static [ProviderCapability] {
        SNOWFLAKE_CAPABILITIES
    }

    fn models(&self) -> &[ModelInfo] {
        &self.models
    }

    /// OpenAI-compatible request parameters this provider forwards upstream.
    fn get_supported_openai_params(&self, _model: &str) -> &'static [&'static str] {
        static PARAMS: &[&str] = &[
            "stream",
            "max_tokens",
            "temperature",
            "top_p",
            "stop",
            "tools",
            "tool_choice",
        ];
        PARAMS
    }

    /// Renames `max_completion_tokens` to `max_tokens`; all other parameters
    /// pass through untouched.
    async fn map_openai_params(
        &self,
        mut params: HashMap<String, serde_json::Value>,
        _model: &str,
    ) -> Result<HashMap<String, serde_json::Value>, ProviderError> {
        if let Some(max_completion_tokens) = params.remove("max_completion_tokens") {
            params.insert("max_tokens".to_string(), max_completion_tokens);
        }
        Ok(params)
    }

    /// Serializes the request as-is; Snowflake-specific body shaping happens
    /// in `chat_completion`.
    async fn transform_request(
        &self,
        request: ChatRequest,
        _context: RequestContext,
    ) -> Result<serde_json::Value, ProviderError> {
        serde_json::to_value(&request)
            .map_err(|e| SnowflakeError::invalid_request("snowflake", e.to_string()))
    }

    /// Deserializes a raw response body into a `ChatResponse`, mapping parse
    /// failures to a 500 api error.
    async fn transform_response(
        &self,
        raw_response: &[u8],
        _model: &str,
        _request_id: &str,
    ) -> Result<ChatResponse, ProviderError> {
        serde_json::from_slice(raw_response).map_err(|e| {
            SnowflakeError::api_error("snowflake", 500, format!("Failed to parse response: {}", e))
        })
    }

    fn get_error_mapper(&self) -> Box<dyn ErrorMapper<ProviderError>> {
        Box::new(crate::core::traits::error_mapper::DefaultErrorMapper)
    }

    /// Sends a non-streaming chat completion to the Cortex REST API.
    ///
    /// # Errors
    /// - authentication error when no API key is configured or in the env;
    /// - network error on transport failure;
    /// - api error for non-2xx statuses or unparseable response bodies.
    async fn chat_completion(
        &self,
        request: ChatRequest,
        _context: RequestContext,
    ) -> Result<ChatResponse, ProviderError> {
        debug!("Snowflake chat request: model={}", request.model);
        let api_key = self
            .get_api_key()
            .ok_or_else(|| SnowflakeError::authentication("snowflake", "API key is required"))?;
        let url = format!("{}/cortex/inference:complete", self.get_api_base());
        // Strip the router prefix once; the bare id is what Cortex expects
        // and what we re-prefix when echoing the model back to the caller.
        let model_id = request
            .model
            .strip_prefix("snowflake/")
            .unwrap_or(&request.model)
            .to_string();
        let body = serde_json::json!({
            "model": model_id,
            "messages": request.messages,
            "temperature": request.temperature.unwrap_or(0.7),
            "max_tokens": request.max_tokens.unwrap_or(1024),
            "top_p": request.top_p.unwrap_or(1.0),
        });
        // NOTE(review): building a fresh client per call forgoes connection
        // reuse, and `_pool_manager` sits unused — consider caching a client.
        let client = reqwest::Client::new();
        let response = client
            .post(&url)
            .header("Authorization", format!("Snowflake Token=\"{}\"", api_key))
            .header("Content-Type", "application/json")
            .header("X-Snowflake-Authorization-Token-Type", "KEYPAIR_JWT")
            .json(&body)
            .send()
            .await
            .map_err(|e| SnowflakeError::network("snowflake", e.to_string()))?;
        let status = response.status();
        let response_bytes = response
            .bytes()
            .await
            .map_err(|e| SnowflakeError::network("snowflake", e.to_string()))?;
        if !status.is_success() {
            let body_str = String::from_utf8_lossy(&response_bytes);
            return Err(SnowflakeError::api_error(
                "snowflake",
                status.as_u16(),
                body_str.to_string(),
            ));
        }
        let json: serde_json::Value = serde_json::from_slice(&response_bytes).map_err(|e| {
            SnowflakeError::api_error("snowflake", 500, format!("Failed to parse response: {}", e))
        })?;
        // The completion text is read from `choices[0].messages` as a plain
        // string; any other shape silently yields an empty completion.
        // NOTE(review): confirm this matches the current Cortex response
        // schema (OpenAI-style bodies use `choices[0].message.content`).
        let content = json
            .get("choices")
            .and_then(|c| c.get(0))
            .and_then(|c| c.get("messages"))
            .and_then(|m| m.as_str())
            .unwrap_or("");
        Ok(ChatResponse {
            id: format!("snowflake-{}", uuid::Uuid::new_v4().simple()),
            object: "chat.completion".to_string(),
            created: chrono::Utc::now().timestamp(),
            // Fix: re-prefix the *stripped* id. The previous code prefixed
            // `request.model` verbatim, producing "snowflake/snowflake/<id>"
            // whenever the caller already included the router prefix.
            model: format!("snowflake/{}", model_id),
            choices: vec![crate::core::types::responses::ChatChoice {
                index: 0,
                message: crate::core::types::chat::ChatMessage {
                    role: crate::core::types::message::MessageRole::Assistant,
                    content: Some(crate::core::types::message::MessageContent::Text(
                        content.to_string(),
                    )),
                    thinking: None,
                    name: None,
                    tool_calls: None,
                    tool_call_id: None,
                    function_call: None,
                },
                finish_reason: Some(crate::core::types::responses::FinishReason::Stop),
                logprobs: None,
            }],
            usage: None,
            system_fingerprint: None,
        })
    }

    /// Streaming is not implemented yet and always errors.
    // NOTE(review): `SNOWFLAKE_CAPABILITIES` still advertises
    // `ChatCompletionStream`; either implement this or stop advertising it.
    async fn chat_completion_stream(
        &self,
        _request: ChatRequest,
        _context: RequestContext,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<ChatChunk, ProviderError>> + Send>>, ProviderError>
    {
        Err(SnowflakeError::not_supported(
            "snowflake",
            "Streaming not yet implemented for Snowflake",
        ))
    }

    /// Embeddings are not offered by this provider.
    async fn embeddings(
        &self,
        _request: EmbeddingRequest,
        _context: RequestContext,
    ) -> Result<EmbeddingResponse, ProviderError> {
        Err(SnowflakeError::not_supported(
            "snowflake",
            "Embeddings not supported by Snowflake Cortex provider",
        ))
    }

    /// Healthy iff the stored configuration still validates; no network call
    /// is made.
    async fn health_check(&self) -> HealthStatus {
        if self.config.validate().is_ok() {
            HealthStatus::Healthy
        } else {
            HealthStatus::Unhealthy
        }
    }

    /// Cost tracking is not wired up for Snowflake; always reports 0.
    async fn calculate_cost(
        &self,
        _model: &str,
        _input_tokens: u32,
        _output_tokens: u32,
    ) -> Result<f64, ProviderError> {
        Ok(0.0)
    }
}
#[cfg(test)]
mod tests {
    #[allow(unused_imports)]
    use super::*;

    /// The provider-level capability list must advertise chat, streaming and
    /// tool calling, in that order. (The previous test body was empty and
    /// asserted nothing.) `matches!` is used so `ProviderCapability` does not
    /// need to implement `PartialEq`.
    #[test]
    fn test_snowflake_capabilities() {
        assert_eq!(SNOWFLAKE_CAPABILITIES.len(), 3);
        assert!(matches!(
            SNOWFLAKE_CAPABILITIES[0],
            ProviderCapability::ChatCompletion
        ));
        assert!(matches!(
            SNOWFLAKE_CAPABILITIES[1],
            ProviderCapability::ChatCompletionStream
        ));
        assert!(matches!(
            SNOWFLAKE_CAPABILITIES[2],
            ProviderCapability::ToolCalling
        ));
    }
}