use std::collections::HashMap;
use crate::types::content::{Message, SystemContentBlock};
use crate::types::errors::StrandsError;
use crate::types::tools::{ToolChoice, ToolSpec};
use super::{Model, ModelConfig, StreamEventStream};
#[derive(Debug, Clone, Default)]
pub struct SageMakerEndpointConfig {
pub endpoint_name: String,
pub region_name: Option<String>,
pub inference_component_name: Option<String>,
pub target_model: Option<String>,
pub target_variant: Option<String>,
pub additional_args: Option<HashMap<String, serde_json::Value>>,
}
impl SageMakerEndpointConfig {
pub fn new(endpoint_name: impl Into<String>) -> Self {
Self {
endpoint_name: endpoint_name.into(),
..Default::default()
}
}
pub fn with_region(mut self, region: impl Into<String>) -> Self {
self.region_name = Some(region.into());
self
}
pub fn with_inference_component(mut self, component: impl Into<String>) -> Self {
self.inference_component_name = Some(component.into());
self
}
pub fn with_target_model(mut self, model: impl Into<String>) -> Self {
self.target_model = Some(model.into());
self
}
}
#[derive(Debug, Clone, Default)]
pub struct SageMakerPayloadConfig {
pub max_tokens: Option<u32>,
pub stream: bool,
pub temperature: Option<f64>,
pub top_p: Option<f64>,
pub top_k: Option<u32>,
pub stop: Option<Vec<String>>,
pub tool_results_as_user_messages: bool,
pub additional_args: Option<HashMap<String, serde_json::Value>>,
}
impl SageMakerPayloadConfig {
pub fn new() -> Self {
Self {
stream: true,
..Default::default()
}
}
pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
self.max_tokens = Some(max_tokens);
self
}
pub fn with_temperature(mut self, temperature: f64) -> Self {
self.temperature = Some(temperature);
self
}
pub fn with_stream(mut self, stream: bool) -> Self {
self.stream = stream;
self
}
}
/// [`Model`] implementation that targets an Amazon SageMaker endpoint.
///
/// Pairs the generic [`ModelConfig`] with SageMaker-specific endpoint and payload
/// settings. The `ModelConfig` is built from the endpoint name.
pub struct SageMakerModel {
    /// Generic model configuration; rebuilt from the endpoint name by `new` and
    /// `update_endpoint_config`.
    config: ModelConfig,
    /// Which endpoint to talk to.
    endpoint_config: SageMakerEndpointConfig,
    /// Inference parameters for the request payload.
    payload_config: SageMakerPayloadConfig,
}

impl SageMakerModel {
    /// Creates a model whose generic config is derived from `endpoint_config.endpoint_name`.
    pub fn new(
        endpoint_config: SageMakerEndpointConfig,
        payload_config: SageMakerPayloadConfig,
    ) -> Self {
        // Build the generic config while the endpoint name can still be borrowed,
        // before `endpoint_config` is moved into the struct.
        let config = ModelConfig::new(&endpoint_config.endpoint_name);
        Self {
            config,
            endpoint_config,
            payload_config,
        }
    }

    /// Borrows the current endpoint settings.
    pub fn endpoint_config(&self) -> &SageMakerEndpointConfig {
        &self.endpoint_config
    }

    /// Borrows the current payload settings.
    pub fn payload_config(&self) -> &SageMakerPayloadConfig {
        &self.payload_config
    }

    /// Swaps in new endpoint settings and rebuilds the generic config from the new
    /// endpoint name.
    ///
    /// NOTE(review): the whole `ModelConfig` is replaced, so anything applied earlier via
    /// `Model::update_config` is discarded — confirm that is intended.
    pub fn update_endpoint_config(&mut self, config: SageMakerEndpointConfig) {
        let rebuilt = ModelConfig::new(&config.endpoint_name);
        self.config = rebuilt;
        self.endpoint_config = config;
    }

    /// Swaps in new payload settings; the generic config is left untouched.
    pub fn update_payload_config(&mut self, config: SageMakerPayloadConfig) {
        self.payload_config = config;
    }
}
impl Model for SageMakerModel {
    /// Borrows the generic model configuration.
    fn config(&self) -> &ModelConfig {
        &self.config
    }

    /// Replaces the generic configuration wholesale.
    ///
    /// NOTE(review): `endpoint_config` is not touched here, so `model_id` may stop matching
    /// the endpoint name after this call — confirm that is acceptable.
    fn update_config(&mut self, config: ModelConfig) {
        self.config = config;
    }

    /// Placeholder streaming implementation.
    ///
    /// No SageMaker runtime client is wired up yet. Every call returns a stream that
    /// yields exactly one `StrandsError::ModelError` item and then ends. All arguments
    /// are currently ignored.
    fn stream<'a>(
        &'a self,
        _messages: &'a [Message],
        _tool_specs: Option<&'a [ToolSpec]>,
        _system_prompt: Option<&'a str>,
        _tool_choice: Option<ToolChoice>,
        _system_prompt_content: Option<&'a [SystemContentBlock]>,
    ) -> StreamEventStream<'a> {
        // The error is built lazily, when the single-item stream is first polled.
        let unsupported = async {
            Err(StrandsError::ModelError {
                message: "SageMaker integration requires aws-sdk-sagemakerruntime implementation"
                    .into(),
                source: None,
            })
        };
        Box::pin(futures::stream::once(unsupported))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sagemaker_endpoint_config() {
        let config = SageMakerEndpointConfig::new("my-endpoint")
            .with_region("us-west-2")
            .with_target_model("my-model");
        assert_eq!(config.endpoint_name, "my-endpoint");
        assert_eq!(config.region_name, Some("us-west-2".to_string()));
        assert_eq!(config.target_model, Some("my-model".to_string()));
        // Fields that were not set through a builder stay unset.
        assert!(config.inference_component_name.is_none());
        assert!(config.target_variant.is_none());
        assert!(config.additional_args.is_none());
    }

    #[test]
    fn test_sagemaker_endpoint_config_inference_component() {
        let config = SageMakerEndpointConfig::new("ep").with_inference_component("ic-1");
        assert_eq!(config.inference_component_name.as_deref(), Some("ic-1"));
        assert!(config.region_name.is_none());
        assert!(config.target_model.is_none());
    }

    #[test]
    fn test_sagemaker_payload_config() {
        let config = SageMakerPayloadConfig::new()
            .with_max_tokens(1000)
            .with_temperature(0.7);
        assert_eq!(config.max_tokens, Some(1000));
        assert_eq!(config.temperature, Some(0.7));
        assert!(config.stream);
    }

    #[test]
    fn test_sagemaker_payload_config_disable_stream() {
        let config = SageMakerPayloadConfig::new().with_stream(false);
        assert!(!config.stream);
        assert!(config.max_tokens.is_none());
        assert!(config.temperature.is_none());
    }

    #[test]
    fn test_sagemaker_model_creation() {
        let endpoint_config = SageMakerEndpointConfig::new("test-endpoint");
        let payload_config = SageMakerPayloadConfig::new();
        let model = SageMakerModel::new(endpoint_config, payload_config);
        assert_eq!(model.config().model_id, "test-endpoint");
        assert_eq!(model.endpoint_config().endpoint_name, "test-endpoint");
        assert!(model.payload_config().stream);
    }

    #[test]
    fn test_update_endpoint_config_resyncs_model_id() {
        let mut model = SageMakerModel::new(
            SageMakerEndpointConfig::new("old-endpoint"),
            SageMakerPayloadConfig::new(),
        );
        model.update_endpoint_config(SageMakerEndpointConfig::new("new-endpoint"));
        assert_eq!(model.config().model_id, "new-endpoint");
        assert_eq!(model.endpoint_config().endpoint_name, "new-endpoint");
    }

    #[test]
    fn test_update_payload_config() {
        let mut model = SageMakerModel::new(
            SageMakerEndpointConfig::new("ep"),
            SageMakerPayloadConfig::new(),
        );
        model.update_payload_config(SageMakerPayloadConfig::new().with_max_tokens(42));
        assert_eq!(model.payload_config().max_tokens, Some(42));
        // Payload updates must not disturb the model id.
        assert_eq!(model.config().model_id, "ep");
    }
}