use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::time::Duration;
/// Configuration for deploying an ML inference function to AWS Lambda.
///
/// Build one with [`LambdaConfig::new`] and the `with_*` builder methods,
/// then check it with [`LambdaConfig::validate`] before deploying.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LambdaConfig {
    /// Lambda function name (required; validated as non-empty).
    pub function_name: String,
    /// AWS region, e.g. `us-east-1` (the default).
    pub region: String,
    /// Function memory in MB; the builder clamps to 128..=10240.
    pub memory_mb: u32,
    /// Invocation timeout in seconds; the builder clamps to 1..=900.
    pub timeout_secs: u32,
    /// Runtime the function executes under.
    pub runtime: LambdaRuntime,
    /// URI of the model artifact to load (required; validated as non-empty).
    pub model_uri: String,
    /// Extra environment variables injected into the function.
    pub environment: HashMap<String, String>,
    /// Provisioned concurrency count; 0 disables provisioning.
    pub provisioned_concurrency: u32,
    /// Optional VPC attachment (subnets + security groups).
    pub vpc_config: Option<VpcConfig>,
    /// Ephemeral /tmp storage in MB; the builder clamps to 512..=10240.
    pub ephemeral_storage_mb: u32,
    /// CPU architecture for the function.
    pub architecture: LambdaArchitecture,
}
impl Default for LambdaConfig {
    /// Defaults: 3008 MB / 60 s in `us-east-1`, custom `provided` runtime
    /// on arm64, 512 MB ephemeral storage, no VPC, no provisioned
    /// concurrency, and empty name/model (both must be set before use).
    fn default() -> Self {
        Self {
            function_name: String::new(),
            model_uri: String::new(),
            region: String::from("us-east-1"),
            memory_mb: 3008,
            timeout_secs: 60,
            runtime: LambdaRuntime::default(),
            architecture: LambdaArchitecture::default(),
            environment: HashMap::default(),
            provisioned_concurrency: 0,
            vpc_config: None,
            ephemeral_storage_mb: 512,
        }
    }
}
impl LambdaConfig {
    /// Creates a config with the given function name; every other field
    /// takes its [`Default`] value.
    #[must_use]
    pub fn new(function_name: impl Into<String>) -> Self {
        Self { function_name: function_name.into(), ..Self::default() }
    }

    /// Sets the model artifact URI.
    #[must_use]
    pub fn with_model(mut self, model_uri: impl Into<String>) -> Self {
        self.model_uri = model_uri.into();
        self
    }

    /// Sets memory, clamped to Lambda's 128..=10240 MB limits.
    #[must_use]
    pub fn with_memory(mut self, mb: u32) -> Self {
        self.memory_mb = mb.clamp(128, 10240);
        self
    }

    /// Sets the timeout, clamped to Lambda's 1..=900 second limits.
    #[must_use]
    pub fn with_timeout(mut self, secs: u32) -> Self {
        self.timeout_secs = secs.clamp(1, 900);
        self
    }

    /// Sets the AWS region.
    #[must_use]
    pub fn with_region(mut self, region: impl Into<String>) -> Self {
        self.region = region.into();
        self
    }

    /// Selects the Lambda runtime.
    #[must_use]
    pub fn with_runtime(mut self, runtime: LambdaRuntime) -> Self {
        self.runtime = runtime;
        self
    }

    /// Adds (or overwrites) one environment variable.
    #[must_use]
    pub fn with_env(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.environment.insert(key.into(), value.into());
        self
    }

    /// Sets the provisioned-concurrency count (0 disables it).
    #[must_use]
    pub fn with_provisioned_concurrency(mut self, count: u32) -> Self {
        self.provisioned_concurrency = count;
        self
    }

    /// Attaches the function to a VPC.
    #[must_use]
    pub fn with_vpc(mut self, vpc: VpcConfig) -> Self {
        self.vpc_config = Some(vpc);
        self
    }

    /// Sets ephemeral /tmp storage, clamped to 512..=10240 MB.
    #[must_use]
    pub fn with_storage(mut self, mb: u32) -> Self {
        self.ephemeral_storage_mb = mb.clamp(512, 10240);
        self
    }

    /// Selects the CPU architecture.
    #[must_use]
    pub fn with_architecture(mut self, arch: LambdaArchitecture) -> Self {
        self.architecture = arch;
        self
    }

    /// Checks that required fields are set and numeric limits hold.
    ///
    /// # Errors
    /// Returns a [`ConfigError`] for the first violated constraint.
    /// Values set via the `with_*` builders are already clamped, so range
    /// errors can only arise from directly-assigned fields.
    pub fn validate(&self) -> Result<(), ConfigError> {
        if self.function_name.is_empty() {
            return Err(ConfigError::MissingField("function_name"));
        }
        if self.model_uri.is_empty() {
            return Err(ConfigError::MissingField("model_uri"));
        }
        if !(128..=10240).contains(&self.memory_mb) {
            return Err(ConfigError::InvalidMemory(self.memory_mb));
        }
        if !(1..=900).contains(&self.timeout_secs) {
            return Err(ConfigError::InvalidTimeout(self.timeout_secs));
        }
        Ok(())
    }

    /// Rough monthly USD cost for the given request volume and mean
    /// invocation duration.
    ///
    /// Sums per-GB-second compute, per-million-request, and (when enabled)
    /// provisioned-concurrency charges over an assumed 30-day month.
    /// NOTE(review): the rates are hard-coded snapshots of public Lambda
    /// pricing and may drift — confirm against current pricing.
    #[must_use]
    pub fn estimate_cost(&self, invocations_per_month: u64, avg_duration_ms: u64) -> f64 {
        const COMPUTE_PER_GB_SECOND: f64 = 0.0000133334;
        const REQUESTS_PER_MILLION: f64 = 0.20;
        const PROVISIONED_PER_GB_SECOND: f64 = 0.0000041667;

        let gb = self.memory_mb as f64 / 1024.0;
        let gb_seconds = gb * (avg_duration_ms as f64 / 1000.0) * invocations_per_month as f64;
        let compute_cost = gb_seconds * COMPUTE_PER_GB_SECOND;
        let request_cost = (invocations_per_month as f64 / 1_000_000.0) * REQUESTS_PER_MILLION;
        let provisioned_cost = if self.provisioned_concurrency == 0 {
            0.0
        } else {
            // Provisioned GB, billed for every second of a 30-day month.
            gb * self.provisioned_concurrency as f64
                * PROVISIONED_PER_GB_SECOND
                * 3600.0
                * 24.0
                * 30.0
        };
        compute_cost + request_cost + provisioned_cost
    }
}
/// Supported Lambda runtimes for the inference function.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
pub enum LambdaRuntime {
    /// Custom runtime on Amazon Linux 2023 (`provided.al2023`); the default.
    #[default]
    Provided,
    /// Managed Python 3.12 runtime.
    Python312,
    /// Container-image-based deployment.
    Container,
}
impl LambdaRuntime {
    /// The AWS runtime identifier string for this variant.
    #[must_use]
    pub const fn identifier(&self) -> &'static str {
        match *self {
            Self::Provided => "provided.al2023",
            Self::Python312 => "python3.12",
            Self::Container => "container",
        }
    }
}
/// CPU architectures Lambda can run the function on.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
pub enum LambdaArchitecture {
    /// AWS Graviton (arm64); the default.
    #[default]
    Arm64,
    /// x86-64.
    X86_64,
}
impl LambdaArchitecture {
    /// The AWS architecture identifier string for this variant.
    #[must_use]
    pub const fn identifier(&self) -> &'static str {
        match *self {
            Self::Arm64 => "arm64",
            Self::X86_64 => "x86_64",
        }
    }
}
/// VPC attachment for the Lambda function.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VpcConfig {
    /// Subnet IDs the function's ENIs are placed in.
    pub subnet_ids: Vec<String>,
    /// Security groups applied to those ENIs.
    pub security_group_ids: Vec<String>,
}
impl VpcConfig {
    /// Creates a VPC attachment from subnet and security-group IDs.
    #[must_use]
    pub fn new(subnet_ids: Vec<String>, security_group_ids: Vec<String>) -> Self {
        Self {
            subnet_ids,
            security_group_ids,
        }
    }
}
/// Drives deployment of a [`LambdaConfig`] and tracks its lifecycle state.
#[derive(Debug, Clone)]
pub struct LambdaDeployer {
    // The deployment configuration (read-only after construction).
    config: LambdaConfig,
    // Current lifecycle state; see `DeploymentStatus`.
    status: DeploymentStatus,
    // ARN assigned once the function becomes active.
    function_arn: Option<String>,
}
/// Lifecycle states of a Lambda deployment.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
pub enum DeploymentStatus {
    /// No deployment attempted yet (initial state).
    #[default]
    NotDeployed,
    /// Building the deployment package.
    Packaging,
    /// Uploading the package to AWS.
    Uploading,
    /// Creating or updating the function.
    Deploying,
    /// Function is live; its ARN is known.
    Active,
    /// Deployment failed.
    Failed,
}
impl LambdaDeployer {
    /// Wraps a config in a deployer starting in the `NotDeployed` state.
    #[must_use]
    pub fn new(config: LambdaConfig) -> Self {
        Self { config, status: DeploymentStatus::NotDeployed, function_arn: None }
    }

    /// Current lifecycle state.
    #[must_use]
    pub fn status(&self) -> DeploymentStatus {
        self.status
    }

    /// ARN of the deployed function, if it has been deployed.
    #[must_use]
    pub fn function_arn(&self) -> Option<&str> {
        self.function_arn.as_deref()
    }

    /// The deployment configuration.
    #[must_use]
    pub fn config(&self) -> &LambdaConfig {
        &self.config
    }

    /// Validates the underlying configuration.
    ///
    /// # Errors
    /// Returns [`DeploymentError::Config`] when the config is invalid.
    pub fn validate(&self) -> Result<(), DeploymentError> {
        self.config.validate().map_err(DeploymentError::Config)
    }

    /// Produces rough size/latency/cost estimates for this deployment.
    #[must_use]
    pub fn estimate(&self) -> DeploymentEstimate {
        // NOTE(review): the model size is a hard-coded 1 GiB placeholder —
        // the artifact at `model_uri` is not inspected here; confirm or
        // wire in the real size.
        let model_size_mb = 1024;
        // Model plus runtime/bootstrap overhead.
        let package_size_mb = model_size_mb + 50;
        DeploymentEstimate {
            package_size_mb,
            estimated_cold_start_ms: estimate_cold_start(&self.config),
            monthly_cost_1k_req: self.config.estimate_cost(1000, 500),
            monthly_cost_100k_req: self.config.estimate_cost(100_000, 500),
            monthly_cost_1m_req: self.config.estimate_cost(1_000_000, 500),
        }
    }

    /// Renders a CloudFormation/SAM template for this deployment.
    ///
    /// Fixes two template defects: the provisioned-concurrency resource
    /// was previously built but bound to `_provisioned` and never
    /// interpolated into the output (it is now emitted under `Resources`
    /// when `provisioned_concurrency > 0`), and the `FunctionUrl` output
    /// referenced a `{name}FunctionUrl` resource that was never created
    /// (a SAM `FunctionUrlConfig` is now set, which makes SAM generate
    /// that resource).
    #[must_use]
    pub fn generate_iac(&self) -> String {
        let name = &self.config.function_name;

        // Optional VPC attachment, rendered at Properties indentation.
        let vpc_config = match self.config.vpc_config {
            Some(ref vpc) => {
                let yaml_list = |ids: &[String]| -> String {
                    ids.iter()
                        .map(|s| format!("          - {s}"))
                        .collect::<Vec<_>>()
                        .join("\n")
                };
                format!(
                    "\n      VpcConfig:\n        SubnetIds:\n{}\n        SecurityGroupIds:\n{}",
                    yaml_list(&vpc.subnet_ids),
                    yaml_list(&vpc.security_group_ids)
                )
            }
            None => String::new(),
        };

        // Provisioned-concurrency resource, appended to Resources.
        let provisioned = if self.config.provisioned_concurrency > 0 {
            format!(
                "\n  {name}Concurrency:\n    Type: AWS::Lambda::Version\n    Properties:\n      FunctionName: !Ref {name}Function\n      ProvisionedConcurrencyConfig:\n        ProvisionedConcurrentExecutions: {}",
                self.config.provisioned_concurrency
            )
        } else {
            String::new()
        };

        // Extra environment variables, one per line under Variables.
        let environment = self
            .config
            .environment
            .iter()
            .map(|(k, v)| format!("\n          {k}: {v}"))
            .collect::<String>();

        format!(
            r"AWSTemplateFormatVersion: '2010-09-09'
Transform: AWS::Serverless-2016-10-31
Description: ML Inference Lambda - {name}
Resources:
  {name}Function:
    Type: AWS::Serverless::Function
    Properties:
      FunctionName: {name}
      Runtime: {runtime}
      Handler: bootstrap
      CodeUri: ./deployment-package.zip
      MemorySize: {memory}
      Timeout: {timeout}
      Architectures:
        - {arch}
      FunctionUrlConfig:
        AuthType: AWS_IAM
      Environment:
        Variables:
          MODEL_URI: {model}
          RUST_LOG: info{environment}{vpc_config}
      EphemeralStorage:
        Size: {storage}{provisioned}
Outputs:
  FunctionArn:
    Description: Lambda Function ARN
    Value: !GetAtt {name}Function.Arn
  FunctionUrl:
    Description: Lambda Function URL
    Value: !GetAtt {name}FunctionUrl.FunctionUrl",
            name = name,
            runtime = self.config.runtime.identifier(),
            memory = self.config.memory_mb,
            timeout = self.config.timeout_secs,
            arch = self.config.architecture.identifier(),
            model = self.config.model_uri,
            environment = environment,
            vpc_config = vpc_config,
            storage = self.config.ephemeral_storage_mb,
            provisioned = provisioned,
        )
    }

    /// Records a new lifecycle state.
    pub fn set_status(&mut self, status: DeploymentStatus) {
        self.status = status;
    }

    /// Records the deployed function's ARN and marks the deployment active.
    pub fn set_function_arn(&mut self, arn: String) {
        self.function_arn = Some(arn);
        self.status = DeploymentStatus::Active;
    }
}
/// Pre-deployment estimates produced by `LambdaDeployer::estimate`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeploymentEstimate {
    /// Estimated deployment-package size (model + overhead), in MB.
    pub package_size_mb: u64,
    /// Heuristic cold-start latency in milliseconds.
    pub estimated_cold_start_ms: u64,
    /// Monthly cost (USD) at 1,000 requests/month, 500 ms average duration.
    pub monthly_cost_1k_req: f64,
    /// Monthly cost (USD) at 100,000 requests/month, 500 ms average duration.
    pub monthly_cost_100k_req: f64,
    /// Monthly cost (USD) at 1,000,000 requests/month, 500 ms average duration.
    pub monthly_cost_1m_req: f64,
}
/// Heuristic cold-start estimate in milliseconds: a per-runtime base cost,
/// scaled up for small memory sizes, plus a fixed model-load allowance.
fn estimate_cold_start(config: &LambdaConfig) -> u64 {
    // Flat allowance for fetching/loading the model artifact.
    const MODEL_LOAD_MS: u64 = 2000;

    let runtime_base: u64 = match config.runtime {
        LambdaRuntime::Provided => 100,
        LambdaRuntime::Python312 => 200,
        LambdaRuntime::Container => 500,
    };
    // Below 3008 MB the estimate scales linearly from 1.5x (near 0 MB)
    // down to 1.0x at 3008 MB; at or above 3008 MB no slowdown is applied.
    let slowdown = if config.memory_mb >= 3008 {
        1.0
    } else {
        1.5 - (config.memory_mb as f64 / 6016.0)
    };
    ((runtime_base as f64 * slowdown) as u64) + MODEL_LOAD_MS
}
/// Payload sent to the inference function.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InferenceRequest {
    /// Input text/prompt for the model.
    pub input: String,
    /// Optional cap on generated tokens.
    pub max_tokens: Option<u32>,
    /// Optional sampling temperature.
    pub temperature: Option<f32>,
    /// Additional free-form model parameters.
    pub parameters: HashMap<String, serde_json::Value>,
}
impl InferenceRequest {
    /// Builds a request with the given input text and no optional settings.
    #[must_use]
    pub fn new(input: impl Into<String>) -> Self {
        Self {
            input: input.into(),
            max_tokens: None,
            temperature: None,
            parameters: HashMap::default(),
        }
    }

    /// Caps the number of tokens to generate.
    #[must_use]
    pub fn with_max_tokens(mut self, tokens: u32) -> Self {
        self.max_tokens = Some(tokens);
        self
    }

    /// Sets the sampling temperature.
    #[must_use]
    pub fn with_temperature(mut self, temp: f32) -> Self {
        self.temperature = Some(temp);
        self
    }
}
/// Result returned by the inference function.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InferenceResponse {
    /// Generated output text.
    pub output: String,
    /// Number of tokens generated.
    pub tokens_generated: u32,
    /// End-to-end invocation latency in milliseconds.
    pub latency_ms: u64,
    /// Whether this invocation incurred a cold start.
    pub cold_start: bool,
}
/// Client-side handle for a deployed Lambda function.
///
/// NOTE(review): holds connection settings only — no invoke method is
/// visible in this file; presumably implemented elsewhere.
#[derive(Debug, Clone)]
pub struct LambdaClient {
    // Full ARN of the target function.
    function_arn: String,
    // AWS region the function lives in.
    region: String,
    // Per-invocation timeout (defaults to 60 s in `new`).
    timeout: Duration,
}
impl LambdaClient {
    /// Creates a client for the given function ARN and region with a
    /// default 60-second invocation timeout.
    #[must_use]
    pub fn new(function_arn: impl Into<String>, region: impl Into<String>) -> Self {
        let timeout = Duration::from_secs(60);
        Self {
            function_arn: function_arn.into(),
            region: region.into(),
            timeout,
        }
    }

    /// Overrides the invocation timeout.
    #[must_use]
    pub fn with_timeout(mut self, timeout: Duration) -> Self {
        self.timeout = timeout;
        self
    }

    /// The target function's ARN.
    #[must_use]
    pub fn function_arn(&self) -> &str {
        self.function_arn.as_str()
    }

    /// The AWS region in use.
    #[must_use]
    pub fn region(&self) -> &str {
        self.region.as_str()
    }

    /// The configured invocation timeout.
    #[must_use]
    pub fn timeout(&self) -> Duration {
        self.timeout
    }
}
/// Validation errors produced by `LambdaConfig::validate`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ConfigError {
    /// A required field (named by the payload) was empty.
    MissingField(&'static str),
    /// Memory outside the allowed 128-10240 MB range.
    InvalidMemory(u32),
    /// Timeout outside the allowed 1-900 s range.
    InvalidTimeout(u32),
}
impl std::fmt::Display for ConfigError {
    /// Human-readable description of the configuration problem.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // All payloads are `Copy`, so matching on `*self` avoids refs.
        match *self {
            Self::MissingField(field) => write!(f, "Missing required field: {field}"),
            Self::InvalidMemory(mb) => write!(f, "Invalid memory size: {mb}MB (must be 128-10240)"),
            Self::InvalidTimeout(secs) => write!(f, "Invalid timeout: {secs}s (must be 1-900)"),
        }
    }
}
impl std::error::Error for ConfigError {}
/// Errors that can occur while deploying a function.
#[derive(Debug)]
pub enum DeploymentError {
    /// The underlying configuration was invalid.
    Config(ConfigError),
    /// An error reported by AWS.
    AwsError(String),
    /// The model artifact could not be found at the given URI.
    ModelNotFound(String),
    /// The deployment package exceeds the size limit.
    PackageTooLarge { size_mb: u64, max_mb: u64 },
}
impl std::fmt::Display for DeploymentError {
    /// Human-readable message for each deployment failure mode.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Config(e) => write!(f, "Configuration error: {e}"),
            Self::AwsError(e) => write!(f, "AWS error: {e}"),
            Self::ModelNotFound(uri) => write!(f, "Model not found: {uri}"),
            Self::PackageTooLarge { size_mb, max_mb } => {
                write!(f, "Package too large: {size_mb}MB (max {max_mb}MB)")
            }
        }
    }
}
impl std::error::Error for DeploymentError {}
#[cfg(test)]
#[allow(non_snake_case)]
#[path = "lambda_tests.rs"]
mod tests;