// strands-agents 0.1.0
//
// A Rust implementation of the Strands AI Agents SDK.
// Documentation
//! Amazon SageMaker model provider.
//!
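//! This module defines the endpoint and payload configuration types and the
//! `SageMakerModel` provider for Amazon SageMaker endpoints. The endpoint
//! invocation itself is not implemented yet: `SageMakerModel::stream` always
//! yields a single `StrandsError::ModelError`.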

use std::collections::HashMap;

use crate::types::content::{Message, SystemContentBlock};
use crate::types::errors::StrandsError;
use crate::types::tools::{ToolChoice, ToolSpec};

use super::{Model, ModelConfig, StreamEventStream};

/// Endpoint configuration for SageMaker.
///
/// Describes which endpoint a `SageMakerModel` is pointed at. The derived
/// `Default` leaves `endpoint_name` empty and every optional field `None`;
/// use [`SageMakerEndpointConfig::new`] to start from a named endpoint.
//
// NOTE(review): nothing in this file reads these fields when building a
// request yet (the `stream` implementation is a stub) — presumably they map
// onto the SageMaker runtime invoke parameters; confirm when that code lands.
#[derive(Debug, Clone, Default)]
pub struct SageMakerEndpointConfig {
    /// The name of the SageMaker endpoint to invoke.
    ///
    /// Also used as the model id of the `ModelConfig` built by
    /// `SageMakerModel::new`.
    pub endpoint_name: String,
    /// AWS region name, or `None` when no region was chosen explicitly.
    pub region_name: Option<String>,
    /// The name of the inference component to use, if any.
    pub inference_component_name: Option<String>,
    /// Target model for multi-model endpoints, if any.
    pub target_model: Option<String>,
    /// Target variant, if any. No builder method sets this; assign the field
    /// directly.
    pub target_variant: Option<String>,
    /// Additional arguments for the request, keyed by argument name. No
    /// builder method sets this; assign the field directly.
    pub additional_args: Option<HashMap<String, serde_json::Value>>,
}

impl SageMakerEndpointConfig {
    /// Build a config that targets `endpoint_name`.
    ///
    /// Every optional setting starts out as `None`, exactly as the derived
    /// `Default` would leave it.
    pub fn new(endpoint_name: impl Into<String>) -> Self {
        Self {
            endpoint_name: endpoint_name.into(),
            region_name: None,
            inference_component_name: None,
            target_model: None,
            target_variant: None,
            additional_args: None,
        }
    }

    /// Return this config with `region_name` set to `region`.
    pub fn with_region(self, region: impl Into<String>) -> Self {
        Self {
            region_name: Some(region.into()),
            ..self
        }
    }

    /// Return this config with `inference_component_name` set to `component`.
    pub fn with_inference_component(self, component: impl Into<String>) -> Self {
        Self {
            inference_component_name: Some(component.into()),
            ..self
        }
    }

    /// Return this config with `target_model` set to `model`.
    pub fn with_target_model(self, model: impl Into<String>) -> Self {
        Self {
            target_model: Some(model.into()),
            ..self
        }
    }
}

/// Payload configuration for SageMaker.
///
/// Holds the generation settings for a request. Streaming is on by default:
/// [`SageMakerPayloadConfig::new`] and [`Default::default`] both produce
/// `stream: true`, `tool_results_as_user_messages: false`, and `None` for
/// every optional field.
#[derive(Debug, Clone)]
pub struct SageMakerPayloadConfig {
    /// Maximum number of tokens to generate.
    pub max_tokens: Option<u32>,
    /// Whether to stream the response. Defaults to `true`.
    pub stream: bool,
    /// Sampling temperature.
    pub temperature: Option<f64>,
    /// Top-p (nucleus sampling).
    pub top_p: Option<f64>,
    /// Top-k sampling.
    pub top_k: Option<u32>,
    /// Stop sequences.
    pub stop: Option<Vec<String>>,
    /// Convert tool results to user messages. Defaults to `false`.
    pub tool_results_as_user_messages: bool,
    /// Additional arguments, keyed by argument name.
    pub additional_args: Option<HashMap<String, serde_json::Value>>,
}

impl Default for SageMakerPayloadConfig {
    /// Streaming-enabled defaults, identical to [`SageMakerPayloadConfig::new`].
    ///
    /// Implemented by hand rather than derived: a derived `Default` would set
    /// `stream` to `false` and silently disagree with `new()`, so
    /// `..Default::default()` in a struct literal would turn streaming off.
    fn default() -> Self {
        Self {
            max_tokens: None,
            stream: true,
            temperature: None,
            top_p: None,
            top_k: None,
            stop: None,
            tool_results_as_user_messages: false,
            additional_args: None,
        }
    }
}

impl SageMakerPayloadConfig {
    /// Create a new payload config with streaming enabled and nothing else set.
    ///
    /// Equivalent to `SageMakerPayloadConfig::default()`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Return this config with `max_tokens` set to the given limit.
    pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
        self.max_tokens = Some(max_tokens);
        self
    }

    /// Return this config with the sampling `temperature` set.
    pub fn with_temperature(mut self, temperature: f64) -> Self {
        self.temperature = Some(temperature);
        self
    }

    /// Return this config with streaming switched on (`true`) or off (`false`).
    pub fn with_stream(mut self, stream: bool) -> Self {
        self.stream = stream;
        self
    }
}

/// Amazon SageMaker model provider implementation.
///
/// Bundles the generic `ModelConfig` exposed through the `Model` trait with
/// the SageMaker-specific endpoint and payload settings.
pub struct SageMakerModel {
    /// Generic model configuration returned by `Model::config`. Built from the
    /// endpoint name by `new` and rebuilt by `update_endpoint_config`.
    config: ModelConfig,
    /// Which endpoint this model targets.
    endpoint_config: SageMakerEndpointConfig,
    /// Generation settings for requests.
    payload_config: SageMakerPayloadConfig,
}

impl SageMakerModel {
    /// Assemble a model from its endpoint and payload settings.
    ///
    /// The generic `ModelConfig` is derived from the endpoint name.
    pub fn new(
        endpoint_config: SageMakerEndpointConfig,
        payload_config: SageMakerPayloadConfig,
    ) -> Self {
        // Derive the generic config first, while the endpoint name can still
        // be borrowed, then move both configs into the struct.
        let config = ModelConfig::new(&endpoint_config.endpoint_name);

        Self {
            config,
            endpoint_config,
            payload_config,
        }
    }

    /// Borrow the current endpoint configuration.
    pub fn endpoint_config(&self) -> &SageMakerEndpointConfig {
        &self.endpoint_config
    }

    /// Borrow the current payload configuration.
    pub fn payload_config(&self) -> &SageMakerPayloadConfig {
        &self.payload_config
    }

    /// Swap in a new endpoint configuration.
    ///
    /// The stored `ModelConfig` is rebuilt from the new endpoint name, so any
    /// config previously installed through `update_config` is replaced.
    pub fn update_endpoint_config(&mut self, config: SageMakerEndpointConfig) {
        let rebuilt = ModelConfig::new(&config.endpoint_name);
        self.config = rebuilt;
        self.endpoint_config = config;
    }

    /// Swap in a new payload configuration; nothing else is touched.
    pub fn update_payload_config(&mut self, config: SageMakerPayloadConfig) {
        self.payload_config = config;
    }
}

impl Model for SageMakerModel {
    /// Borrow the generic model configuration.
    fn config(&self) -> &ModelConfig {
        &self.config
    }

    /// Replace the generic model configuration wholesale.
    ///
    /// The endpoint and payload configurations are left as they are.
    fn update_config(&mut self, config: ModelConfig) {
        self.config = config;
    }

    /// Placeholder stream: SageMaker invocation is not implemented.
    ///
    /// All arguments are ignored. The returned stream yields exactly one
    /// item, an `Err(StrandsError::ModelError)` explaining that an
    /// `aws-sdk-sagemakerruntime` based implementation is required.
    fn stream<'a>(
        &'a self,
        _messages: &'a [Message],
        _tool_specs: Option<&'a [ToolSpec]>,
        _system_prompt: Option<&'a str>,
        _tool_choice: Option<ToolChoice>,
        _system_prompt_content: Option<&'a [SystemContentBlock]>,
    ) -> StreamEventStream<'a> {
        // The error is constructed inside the future, so the future itself
        // captures nothing from this call.
        let not_implemented = async {
            Err(StrandsError::ModelError {
                message: "SageMaker integration requires aws-sdk-sagemakerruntime implementation".into(),
                source: None,
            })
        };

        Box::pin(futures::stream::once(not_implemented))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Builder calls on the endpoint config land in the matching fields.
    #[test]
    fn test_sagemaker_endpoint_config() {
        let built = SageMakerEndpointConfig::new("my-endpoint")
            .with_region("us-west-2")
            .with_target_model("my-model");

        assert_eq!(built.endpoint_name.as_str(), "my-endpoint");
        assert_eq!(built.region_name.as_deref(), Some("us-west-2"));
        assert_eq!(built.target_model.as_deref(), Some("my-model"));
    }

    /// Builder calls on the payload config stick, and `new()` enables streaming.
    #[test]
    fn test_sagemaker_payload_config() {
        let built = SageMakerPayloadConfig::new()
            .with_max_tokens(1000)
            .with_temperature(0.7);

        assert_eq!(built.max_tokens, Some(1000));
        assert_eq!(built.temperature, Some(0.7));
        assert!(built.stream);
    }

    /// The model id of a freshly created model is the endpoint name.
    #[test]
    fn test_sagemaker_model_creation() {
        let model = SageMakerModel::new(
            SageMakerEndpointConfig::new("test-endpoint"),
            SageMakerPayloadConfig::new(),
        );

        assert_eq!(model.config().model_id, "test-endpoint");
    }
}