use futures::Stream;
use std::collections::HashMap;
use std::pin::Pin;
use std::sync::Arc;
use tracing::debug;
use super::config::SagemakerConfig;
use super::error::{SagemakerError, SagemakerErrorMapper};
use super::sigv4::SagemakerSigV4Signer;
use crate::core::providers::base::{GlobalPoolManager, HttpErrorMapper};
use crate::core::providers::unified_provider::ProviderError;
use crate::core::traits::error_mapper::trait_def::ErrorMapper;
use crate::core::traits::provider::ProviderConfig as _;
use crate::core::traits::provider::llm_provider::trait_definition::LLMProvider;
use crate::core::types::health::HealthStatus;
use crate::core::types::responses::{ChatChunk, ChatResponse, EmbeddingResponse};
use crate::core::types::{chat::ChatRequest, embedding::EmbeddingRequest};
use crate::core::types::{context::RequestContext, model::ModelInfo, model::ProviderCapability};
/// Capabilities advertised by the SageMaker provider.
/// NOTE(review): `ChatCompletionStream` is advertised here, but
/// `chat_completion_stream` currently returns `not_supported` — the stream
/// capability is aspirational until streaming lands. TODO confirm intent.
const SAGEMAKER_CAPABILITIES: &[ProviderCapability] = &[
    ProviderCapability::ChatCompletion,
    ProviderCapability::ChatCompletionStream,
];
/// LLM provider backed by AWS SageMaker inference endpoints.
///
/// Requests are SigV4-signed and sent with a TGI-style JSON payload
/// (`inputs` + `parameters`); see `chat_completion`.
#[derive(Debug, Clone)]
pub struct SagemakerProvider {
    // Validated provider configuration (credentials, region, endpoint URLs).
    config: SagemakerConfig,
    // Constructed in `new` but currently unused — `chat_completion` builds a
    // fresh reqwest::Client per request instead of drawing from this pool.
    _pool_manager: Arc<GlobalPoolManager>,
    // SigV4 signer initialized from the config's AWS credentials.
    signer: SagemakerSigV4Signer,
    // Static, single-entry model list (SageMaker endpoints are opaque).
    models: Vec<ModelInfo>,
}
impl SagemakerProvider {
pub async fn new(config: SagemakerConfig) -> Result<Self, SagemakerError> {
config
.validate()
.map_err(|e| ProviderError::configuration("sagemaker", e))?;
let pool_manager = Arc::new(GlobalPoolManager::new().map_err(|e| {
ProviderError::configuration(
"sagemaker",
format!("Failed to create pool manager: {}", e),
)
})?);
let signer = SagemakerSigV4Signer::new(
config.get_access_key_id().unwrap_or_default(),
config.get_secret_access_key().unwrap_or_default(),
config.get_session_token(),
config.get_region(),
);
let models = vec![ModelInfo {
id: "sagemaker-endpoint".to_string(),
name: "Sagemaker Endpoint".to_string(),
provider: "sagemaker".to_string(),
max_context_length: 4096,
max_output_length: Some(4096),
supports_streaming: true,
supports_tools: false,
supports_multimodal: false,
input_cost_per_1k_tokens: None,
output_cost_per_1k_tokens: None,
currency: "USD".to_string(),
capabilities: vec![
ProviderCapability::ChatCompletion,
ProviderCapability::ChatCompletionStream,
],
created_at: None,
updated_at: None,
metadata: HashMap::new(),
}];
Ok(Self {
config,
_pool_manager: pool_manager,
signer,
models,
})
}
pub async fn with_credentials(
access_key_id: impl Into<String>,
secret_access_key: impl Into<String>,
region: impl Into<String>,
) -> Result<Self, SagemakerError> {
let config = SagemakerConfig {
aws_access_key_id: Some(access_key_id.into()),
aws_secret_access_key: Some(secret_access_key.into()),
aws_region: Some(region.into()),
..Default::default()
};
Self::new(config).await
}
}
impl LLMProvider for SagemakerProvider {
    fn name(&self) -> &'static str {
        "sagemaker"
    }

    fn capabilities(&self) -> &'static [ProviderCapability] {
        SAGEMAKER_CAPABILITIES
    }

    fn models(&self) -> &[ModelInfo] {
        &self.models
    }

    /// OpenAI-style parameters this provider accepts; everything else should
    /// be dropped by the caller before `map_openai_params`.
    fn get_supported_openai_params(&self, _model: &str) -> &'static [&'static str] {
        static PARAMS: &[&str] = &["stream", "max_tokens", "temperature", "top_p", "stop"];
        PARAMS
    }

    /// Normalizes OpenAI parameters for the SageMaker/TGI backend:
    /// - `max_completion_tokens` is renamed to `max_tokens`;
    /// - a temperature of exactly 0.0 is bumped to 0.01 unless the config sets
    ///   `allow_zero_temp` (presumably because the backend rejects 0.0 when
    ///   sampling — TODO confirm).
    async fn map_openai_params(
        &self,
        mut params: HashMap<String, serde_json::Value>,
        _model: &str,
    ) -> Result<HashMap<String, serde_json::Value>, ProviderError> {
        if let Some(max_completion_tokens) = params.remove("max_completion_tokens") {
            params.insert("max_tokens".to_string(), max_completion_tokens);
        }
        if !self.config.allow_zero_temp
            && let Some(temp) = params.get("temperature")
            && temp.as_f64() == Some(0.0)
        {
            params.insert("temperature".to_string(), serde_json::json!(0.01));
        }
        Ok(params)
    }

    /// Serializes the chat request as-is; the TGI-specific payload shaping
    /// happens inside `chat_completion`, not here.
    async fn transform_request(
        &self,
        request: ChatRequest,
        _context: RequestContext,
    ) -> Result<serde_json::Value, ProviderError> {
        serde_json::to_value(&request)
            .map_err(|e| ProviderError::invalid_request("sagemaker", e.to_string()))
    }

    /// Deserializes a raw response body directly into a `ChatResponse`.
    async fn transform_response(
        &self,
        raw_response: &[u8],
        _model: &str,
        _request_id: &str,
    ) -> Result<ChatResponse, ProviderError> {
        serde_json::from_slice(raw_response).map_err(|e| {
            ProviderError::response_parsing("sagemaker", format!("Failed to parse response: {}", e))
        })
    }

    fn get_error_mapper(&self) -> Box<dyn ErrorMapper<ProviderError>> {
        Box::new(SagemakerErrorMapper)
    }

    /// Invokes a SageMaker endpoint with a TGI-style payload, signing the
    /// request with SigV4 and mapping HTTP failures to `ProviderError`s.
    ///
    /// # Errors
    /// Serialization, signing, network, and non-2xx HTTP statuses are each
    /// mapped to the corresponding `ProviderError` constructor.
    async fn chat_completion(
        &self,
        request: ChatRequest,
        _context: RequestContext,
    ) -> Result<ChatResponse, ProviderError> {
        debug!("Sagemaker chat request: model={}", request.model);
        // "sagemaker/<endpoint>" routing prefix -> bare endpoint name.
        let endpoint_name = request
            .model
            .strip_prefix("sagemaker/")
            .unwrap_or(&request.model);
        let url = self.config.build_endpoint_url(endpoint_name, false);
        let body = serde_json::json!({
            "inputs": format_messages_for_tgi(&request),
            "parameters": {
                "max_new_tokens": request.max_tokens.unwrap_or(512),
                "temperature": request.temperature.unwrap_or(0.7),
                "top_p": request.top_p.unwrap_or(0.9),
                "do_sample": true,
            }
        });
        let body_str = serde_json::to_string(&body)
            .map_err(|e| ProviderError::invalid_request("sagemaker", e.to_string()))?;
        let headers = self
            .signer
            .sign_request("POST", &url, &HashMap::new(), &body_str, chrono::Utc::now())
            .map_err(|e| {
                ProviderError::authentication("sagemaker", format!("Signing error: {}", e))
            })?;
        // NOTE(review): building a fresh reqwest::Client per request forgoes
        // connection reuse; `_pool_manager` looks intended for this — TODO wire
        // the pooled client in here.
        let client = reqwest::Client::new();
        let mut req_builder = client.post(&url);
        for (key, value) in headers {
            req_builder = req_builder.header(key, value);
        }
        let response = req_builder
            .body(body_str)
            .send()
            .await
            .map_err(|e| ProviderError::network("sagemaker", e.to_string()))?;
        let status = response.status();
        let response_bytes = response
            .bytes()
            .await
            .map_err(|e| ProviderError::network("sagemaker", e.to_string()))?;
        if !status.is_success() {
            let body_str = String::from_utf8_lossy(&response_bytes);
            return Err(match status.as_u16() {
                400 => ProviderError::invalid_request("sagemaker", body_str.to_string()),
                401 | 403 => ProviderError::authentication("sagemaker", body_str.to_string()),
                // 424 is SageMaker's "model invocation failed" status.
                404 | 424 => ProviderError::model_not_found("sagemaker", body_str.to_string()),
                429 => ProviderError::rate_limit("sagemaker", None),
                502 | 503 => {
                    ProviderError::api_error("sagemaker", status.as_u16(), body_str.to_string())
                }
                _ => HttpErrorMapper::map_status_code("sagemaker", status.as_u16(), &body_str),
            });
        }
        // BUGFIX: pass the stripped endpoint name — parse_tgi_response prepends
        // "sagemaker/" itself, so passing request.model produced a doubled
        // "sagemaker/sagemaker/<endpoint>" model id when the request already
        // carried the routing prefix.
        parse_tgi_response(&response_bytes, endpoint_name)
    }

    /// Streaming is not implemented yet; always returns `not_supported`.
    async fn chat_completion_stream(
        &self,
        _request: ChatRequest,
        _context: RequestContext,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<ChatChunk, ProviderError>> + Send>>, ProviderError>
    {
        Err(ProviderError::not_supported(
            "sagemaker",
            "Streaming not yet implemented for Sagemaker".to_string(),
        ))
    }

    /// Embeddings are not supported by this provider.
    async fn embeddings(
        &self,
        _request: EmbeddingRequest,
        _context: RequestContext,
    ) -> Result<EmbeddingResponse, ProviderError> {
        Err(ProviderError::not_supported(
            "sagemaker",
            "Embeddings not supported by Sagemaker provider".to_string(),
        ))
    }

    /// Health is purely config-based: no network probe is performed.
    async fn health_check(&self) -> HealthStatus {
        if self.config.validate().is_ok() {
            HealthStatus::Healthy
        } else {
            HealthStatus::Unhealthy
        }
    }

    /// SageMaker billing is per-instance-hour, not per-token, so the token
    /// cost is always reported as zero.
    async fn calculate_cost(
        &self,
        _model: &str,
        _input_tokens: u32,
        _output_tokens: u32,
    ) -> Result<f64, ProviderError> {
        Ok(0.0)
    }
}
/// Flattens a chat transcript into a single TGI prompt string.
///
/// Every message that has content contributes one "Role: text" line;
/// multimodal `Parts` content is reduced to its text fragments joined by
/// single spaces (non-text parts are dropped). A trailing "Assistant:" cue is
/// appended so the model continues speaking as the assistant. Roles other
/// than System/User/Assistant are rendered as "User".
fn format_messages_for_tgi(request: &ChatRequest) -> String {
    use crate::core::types::content::ContentPart;
    use crate::core::types::message::{MessageContent, MessageRole};

    let mut lines: Vec<String> = Vec::with_capacity(request.messages.len() + 1);
    for msg in &request.messages {
        // Messages without content are skipped entirely.
        let Some(content) = &msg.content else { continue };
        let speaker = match msg.role {
            MessageRole::System => "System",
            MessageRole::User => "User",
            MessageRole::Assistant => "Assistant",
            _ => "User",
        };
        let body = match content {
            MessageContent::Text(t) => t.clone(),
            MessageContent::Parts(parts) => parts
                .iter()
                .filter_map(|part| match part {
                    ContentPart::Text { text } => Some(text.clone()),
                    _ => None,
                })
                .collect::<Vec<_>>()
                .join(" "),
        };
        lines.push(format!("{}: {}\n", speaker, body));
    }
    lines.push("Assistant:".to_string());
    lines.concat()
}
/// Converts a raw TGI response body into an OpenAI-shaped `ChatResponse`.
///
/// TGI may answer with either a bare object `{"generated_text": ...}` or an
/// array of such objects; the first array element wins. A missing or
/// non-string `generated_text` degrades to an empty message rather than an
/// error. The returned model id is prefixed with "sagemaker/".
///
/// # Errors
/// Returns a response-parsing error when the body is not valid JSON.
fn parse_tgi_response(response_bytes: &[u8], model: &str) -> Result<ChatResponse, SagemakerError> {
    use crate::core::types::chat::ChatMessage;
    use crate::core::types::message::{MessageContent, MessageRole};
    use crate::core::types::responses::{ChatChoice, FinishReason};

    let parsed: serde_json::Value = serde_json::from_slice(response_bytes).map_err(|e| {
        ProviderError::response_parsing("sagemaker", format!("Failed to parse response: {}", e))
    })?;
    // Array responses contribute their first element; objects are used as-is.
    let payload = match parsed.as_array() {
        Some(items) => items.first().cloned().unwrap_or(serde_json::Value::Null),
        None => parsed,
    };
    let generated_text = payload
        .get("generated_text")
        .and_then(serde_json::Value::as_str)
        .unwrap_or("")
        .to_string();
    Ok(ChatResponse {
        id: format!("sagemaker-{}", uuid::Uuid::new_v4().simple()),
        object: "chat.completion".to_string(),
        created: chrono::Utc::now().timestamp(),
        model: format!("sagemaker/{}", model),
        choices: vec![ChatChoice {
            index: 0,
            message: ChatMessage {
                role: MessageRole::Assistant,
                content: Some(MessageContent::Text(generated_text)),
                thinking: None,
                name: None,
                tool_calls: None,
                tool_call_id: None,
                function_call: None,
            },
            finish_reason: Some(FinishReason::Stop),
            logprobs: None,
        }],
        usage: None,
        system_fingerprint: None,
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    // Verifies prompt flattening: a single user text message must appear as a
    // "User: <text>" line and the prompt must end with the "Assistant:" cue
    // that prompts the model to respond.
    #[test]
    fn test_format_messages_for_tgi() {
        let request = ChatRequest {
            model: "test".to_string(),
            messages: vec![crate::core::types::chat::ChatMessage {
                role: crate::core::types::message::MessageRole::User,
                content: Some(crate::core::types::message::MessageContent::Text(
                    "Hello".to_string(),
                )),
                thinking: None,
                name: None,
                tool_calls: None,
                tool_call_id: None,
                function_call: None,
            }],
            ..Default::default()
        };
        let prompt = format_messages_for_tgi(&request);
        assert!(prompt.contains("User: Hello"));
        assert!(prompt.ends_with("Assistant:"));
    }
}