// batuta/serve/lambda.rs
//! AWS Lambda Inference Deployment
//!
//! Deploy and manage ML inference on AWS Lambda for serverless, pay-per-use inference.
//!
//! ## Features
//!
//! - Model packaging with Docker/OCI containers
//! - Cold start optimization with provisioned concurrency
//! - Automatic scaling with Lambda's built-in capabilities
//! - Integration with Pacha registry for model artifacts
//!
//! ## Toyota Way Principles
//!
//! - Muda Elimination: Pay only for actual inference compute
//! - Heijunka: Automatic scaling levels inference load
//! - Jidoka: Built-in error handling and retry logic
18use serde::{Deserialize, Serialize};
19use std::collections::HashMap;
20use std::time::Duration;
21
22// ============================================================================
23// SERVE-LAM-001: Lambda Configuration
24// ============================================================================
25
26/// Lambda function configuration for inference
27#[derive(Debug, Clone, Serialize, Deserialize)]
28pub struct LambdaConfig {
29    /// Function name (unique identifier)
30    pub function_name: String,
31    /// AWS region
32    pub region: String,
33    /// Memory size in MB (128-10240)
34    pub memory_mb: u32,
35    /// Timeout in seconds (max 900 for Lambda)
36    pub timeout_secs: u32,
37    /// Runtime environment
38    pub runtime: LambdaRuntime,
39    /// Model reference (Pacha URI)
40    pub model_uri: String,
41    /// Environment variables
42    pub environment: HashMap<String, String>,
43    /// Provisioned concurrency (0 = on-demand only)
44    pub provisioned_concurrency: u32,
45    /// VPC configuration (for Private tier)
46    pub vpc_config: Option<VpcConfig>,
47    /// Ephemeral storage in MB (512-10240)
48    pub ephemeral_storage_mb: u32,
49    /// Architecture
50    pub architecture: LambdaArchitecture,
51}
52
53impl Default for LambdaConfig {
54    fn default() -> Self {
55        Self {
56            function_name: String::new(),
57            region: "us-east-1".to_string(),
58            memory_mb: 3008, // Good for inference
59            timeout_secs: 60,
60            runtime: LambdaRuntime::Provided,
61            model_uri: String::new(),
62            environment: HashMap::new(),
63            provisioned_concurrency: 0,
64            vpc_config: None,
65            ephemeral_storage_mb: 512,
66            architecture: LambdaArchitecture::Arm64, // Better price/perf
67        }
68    }
69}
70
71impl LambdaConfig {
72    /// Create a new Lambda config with function name
73    #[must_use]
74    pub fn new(function_name: impl Into<String>) -> Self {
75        Self { function_name: function_name.into(), ..Default::default() }
76    }
77
78    /// Set the model URI
79    #[must_use]
80    pub fn with_model(mut self, model_uri: impl Into<String>) -> Self {
81        self.model_uri = model_uri.into();
82        self
83    }
84
85    /// Set memory size
86    #[must_use]
87    pub fn with_memory(mut self, mb: u32) -> Self {
88        self.memory_mb = mb.clamp(128, 10240);
89        self
90    }
91
92    /// Set timeout
93    #[must_use]
94    pub fn with_timeout(mut self, secs: u32) -> Self {
95        self.timeout_secs = secs.clamp(1, 900);
96        self
97    }
98
99    /// Set region
100    #[must_use]
101    pub fn with_region(mut self, region: impl Into<String>) -> Self {
102        self.region = region.into();
103        self
104    }
105
106    /// Set runtime
107    #[must_use]
108    pub fn with_runtime(mut self, runtime: LambdaRuntime) -> Self {
109        self.runtime = runtime;
110        self
111    }
112
113    /// Add environment variable
114    #[must_use]
115    pub fn with_env(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
116        self.environment.insert(key.into(), value.into());
117        self
118    }
119
120    /// Set provisioned concurrency
121    #[must_use]
122    pub fn with_provisioned_concurrency(mut self, count: u32) -> Self {
123        self.provisioned_concurrency = count;
124        self
125    }
126
127    /// Set VPC configuration
128    #[must_use]
129    pub fn with_vpc(mut self, vpc: VpcConfig) -> Self {
130        self.vpc_config = Some(vpc);
131        self
132    }
133
134    /// Set ephemeral storage
135    #[must_use]
136    pub fn with_storage(mut self, mb: u32) -> Self {
137        self.ephemeral_storage_mb = mb.clamp(512, 10240);
138        self
139    }
140
141    /// Set architecture
142    #[must_use]
143    pub fn with_architecture(mut self, arch: LambdaArchitecture) -> Self {
144        self.architecture = arch;
145        self
146    }
147
148    /// Validate configuration
149    pub fn validate(&self) -> Result<(), ConfigError> {
150        if self.function_name.is_empty() {
151            return Err(ConfigError::MissingField("function_name"));
152        }
153        if self.model_uri.is_empty() {
154            return Err(ConfigError::MissingField("model_uri"));
155        }
156        if self.memory_mb < 128 || self.memory_mb > 10240 {
157            return Err(ConfigError::InvalidMemory(self.memory_mb));
158        }
159        if self.timeout_secs == 0 || self.timeout_secs > 900 {
160            return Err(ConfigError::InvalidTimeout(self.timeout_secs));
161        }
162        Ok(())
163    }
164
165    /// Estimate monthly cost for given invocations
166    #[must_use]
167    pub fn estimate_cost(&self, invocations_per_month: u64, avg_duration_ms: u64) -> f64 {
168        // Lambda pricing (approximate, us-east-1, ARM):
169        // $0.0000133334 per GB-second
170        // $0.20 per 1M requests
171        // Provisioned: $0.0000041667 per GB-second provisioned
172
173        let gb_seconds = (self.memory_mb as f64 / 1024.0)
174            * (avg_duration_ms as f64 / 1000.0)
175            * invocations_per_month as f64;
176
177        let compute_cost = gb_seconds * 0.0000133334;
178        let request_cost = (invocations_per_month as f64 / 1_000_000.0) * 0.20;
179
180        // Add provisioned concurrency cost (per hour)
181        let provisioned_cost = if self.provisioned_concurrency > 0 {
182            let gb_provisioned =
183                (self.memory_mb as f64 / 1024.0) * self.provisioned_concurrency as f64;
184            gb_provisioned * 0.0000041667 * 3600.0 * 24.0 * 30.0 // Monthly
185        } else {
186            0.0
187        };
188
189        compute_cost + request_cost + provisioned_cost
190    }
191}
192
193/// Lambda runtime environment
194#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
195pub enum LambdaRuntime {
196    /// Custom runtime (provided.al2023)
197    #[default]
198    Provided,
199    /// Python 3.12
200    Python312,
201    /// Container image (Docker)
202    Container,
203}
204
205impl LambdaRuntime {
206    /// Get the AWS runtime identifier
207    #[must_use]
208    pub const fn identifier(&self) -> &'static str {
209        match self {
210            Self::Provided => "provided.al2023",
211            Self::Python312 => "python3.12",
212            Self::Container => "container",
213        }
214    }
215}
216
217/// Lambda architecture
218#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
219pub enum LambdaArchitecture {
220    /// ARM64 (Graviton2) - better price/performance
221    #[default]
222    Arm64,
223    /// x86_64
224    X86_64,
225}
226
227impl LambdaArchitecture {
228    /// Get the AWS architecture identifier
229    #[must_use]
230    pub const fn identifier(&self) -> &'static str {
231        match self {
232            Self::Arm64 => "arm64",
233            Self::X86_64 => "x86_64",
234        }
235    }
236}
237
238/// VPC configuration for Lambda
239#[derive(Debug, Clone, Serialize, Deserialize)]
240pub struct VpcConfig {
241    /// Subnet IDs
242    pub subnet_ids: Vec<String>,
243    /// Security group IDs
244    pub security_group_ids: Vec<String>,
245}
246
247impl VpcConfig {
248    /// Create new VPC config
249    #[must_use]
250    pub fn new(subnet_ids: Vec<String>, security_group_ids: Vec<String>) -> Self {
251        Self { subnet_ids, security_group_ids }
252    }
253}
254
255// ============================================================================
256// SERVE-LAM-002: Lambda Deployer
257// ============================================================================
258
259/// Lambda deployment manager
260#[derive(Debug, Clone)]
261pub struct LambdaDeployer {
262    /// Deployment configuration
263    config: LambdaConfig,
264    /// Deployment status
265    status: DeploymentStatus,
266    /// Function ARN (after deployment)
267    function_arn: Option<String>,
268}
269
270/// Deployment status
271#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
272pub enum DeploymentStatus {
273    /// Not yet deployed
274    #[default]
275    NotDeployed,
276    /// Packaging model
277    Packaging,
278    /// Uploading to S3
279    Uploading,
280    /// Creating/updating function
281    Deploying,
282    /// Active and ready
283    Active,
284    /// Deployment failed
285    Failed,
286}
287
288impl LambdaDeployer {
289    /// Create a new deployer
290    #[must_use]
291    pub fn new(config: LambdaConfig) -> Self {
292        Self { config, status: DeploymentStatus::NotDeployed, function_arn: None }
293    }
294
295    /// Get current deployment status
296    #[must_use]
297    pub fn status(&self) -> DeploymentStatus {
298        self.status
299    }
300
301    /// Get function ARN (if deployed)
302    #[must_use]
303    pub fn function_arn(&self) -> Option<&str> {
304        self.function_arn.as_deref()
305    }
306
307    /// Get configuration
308    #[must_use]
309    pub fn config(&self) -> &LambdaConfig {
310        &self.config
311    }
312
313    /// Validate deployment prerequisites
314    pub fn validate(&self) -> Result<(), DeploymentError> {
315        self.config.validate().map_err(DeploymentError::Config)?;
316        Ok(())
317    }
318
319    /// Estimate deployment (dry run)
320    #[must_use]
321    pub fn estimate(&self) -> DeploymentEstimate {
322        let model_size_mb = 1024; // Placeholder - would be fetched from registry
323        let package_size_mb = model_size_mb + 50; // Model + runtime
324
325        DeploymentEstimate {
326            package_size_mb,
327            estimated_cold_start_ms: estimate_cold_start(&self.config),
328            monthly_cost_1k_req: self.config.estimate_cost(1000, 500),
329            monthly_cost_100k_req: self.config.estimate_cost(100_000, 500),
330            monthly_cost_1m_req: self.config.estimate_cost(1_000_000, 500),
331        }
332    }
333
334    /// Generate infrastructure-as-code (CloudFormation/SAM template)
335    #[must_use]
336    pub fn generate_iac(&self) -> String {
337        let vpc_config = if let Some(ref vpc) = self.config.vpc_config {
338            format!(
339                r"
340      VpcConfig:
341        SubnetIds:
342          {}
343        SecurityGroupIds:
344          {}",
345                vpc.subnet_ids
346                    .iter()
347                    .map(|s| format!("- {s}"))
348                    .collect::<Vec<_>>()
349                    .join("\n          "),
350                vpc.security_group_ids
351                    .iter()
352                    .map(|s| format!("- {s}"))
353                    .collect::<Vec<_>>()
354                    .join("\n          ")
355            )
356        } else {
357            String::new()
358        };
359
360        let _provisioned = if self.config.provisioned_concurrency > 0 {
361            format!(
362                r"
363  {}Concurrency:
364    Type: AWS::Lambda::Version
365    Properties:
366      FunctionName: !Ref {}Function
367      ProvisionedConcurrencyConfig:
368        ProvisionedConcurrentExecutions: {}",
369                self.config.function_name,
370                self.config.function_name,
371                self.config.provisioned_concurrency
372            )
373        } else {
374            String::new()
375        };
376
377        format!(
378            r"AWSTemplateFormatVersion: '2010-09-09'
379Transform: AWS::Serverless-2016-10-31
380Description: ML Inference Lambda - {}
381
382Resources:
383  {}Function:
384    Type: AWS::Serverless::Function
385    Properties:
386      FunctionName: {}
387      Runtime: {}
388      Handler: bootstrap
389      CodeUri: ./deployment-package.zip
390      MemorySize: {}
391      Timeout: {}
392      Architectures:
393        - {}
394      Environment:
395        Variables:
396          MODEL_URI: {}
397          RUST_LOG: info{}{}
398      EphemeralStorage:
399        Size: {}
400
401Outputs:
402  FunctionArn:
403    Description: Lambda Function ARN
404    Value: !GetAtt {}Function.Arn
405  FunctionUrl:
406    Description: Lambda Function URL
407    Value: !GetAtt {}FunctionUrl.FunctionUrl",
408            self.config.function_name,
409            self.config.function_name,
410            self.config.function_name,
411            self.config.runtime.identifier(),
412            self.config.memory_mb,
413            self.config.timeout_secs,
414            self.config.architecture.identifier(),
415            self.config.model_uri,
416            self.config
417                .environment
418                .iter()
419                .map(|(k, v)| format!("\n          {k}: {v}"))
420                .collect::<String>(),
421            vpc_config,
422            self.config.ephemeral_storage_mb,
423            self.config.function_name,
424            self.config.function_name,
425        )
426    }
427
428    /// Set deployment status (for tracking)
429    pub fn set_status(&mut self, status: DeploymentStatus) {
430        self.status = status;
431    }
432
433    /// Set function ARN after successful deployment
434    pub fn set_function_arn(&mut self, arn: String) {
435        self.function_arn = Some(arn);
436        self.status = DeploymentStatus::Active;
437    }
438}
439
440/// Deployment estimate
441#[derive(Debug, Clone, Serialize, Deserialize)]
442pub struct DeploymentEstimate {
443    /// Package size in MB
444    pub package_size_mb: u64,
445    /// Estimated cold start in ms
446    pub estimated_cold_start_ms: u64,
447    /// Monthly cost for 1K requests
448    pub monthly_cost_1k_req: f64,
449    /// Monthly cost for 100K requests
450    pub monthly_cost_100k_req: f64,
451    /// Monthly cost for 1M requests
452    pub monthly_cost_1m_req: f64,
453}
454
455/// Estimate cold start time based on config
456fn estimate_cold_start(config: &LambdaConfig) -> u64 {
457    // Base cold start for provided runtime
458    let base_ms: u64 = match config.runtime {
459        LambdaRuntime::Provided => 100,
460        LambdaRuntime::Python312 => 200,
461        LambdaRuntime::Container => 500,
462    };
463
464    // Memory affects cold start (more memory = faster init)
465    let memory_factor =
466        if config.memory_mb >= 3008 { 1.0 } else { 1.5 - (config.memory_mb as f64 / 6016.0) };
467
468    // Model loading estimate (rough)
469    let model_load_ms: u64 = 2000; // 2 seconds for model loading
470
471    ((base_ms as f64 * memory_factor) as u64) + model_load_ms
472}
473
474// ============================================================================
475// SERVE-LAM-003: Inference Client
476// ============================================================================
477
478/// Lambda inference request
479#[derive(Debug, Clone, Serialize, Deserialize)]
480pub struct InferenceRequest {
481    /// Input text/prompt
482    pub input: String,
483    /// Maximum tokens to generate
484    pub max_tokens: Option<u32>,
485    /// Temperature (0.0-2.0)
486    pub temperature: Option<f32>,
487    /// Additional parameters
488    pub parameters: HashMap<String, serde_json::Value>,
489}
490
491impl InferenceRequest {
492    /// Create a new inference request
493    #[must_use]
494    pub fn new(input: impl Into<String>) -> Self {
495        Self {
496            input: input.into(),
497            max_tokens: None,
498            temperature: None,
499            parameters: HashMap::new(),
500        }
501    }
502
503    /// Set max tokens
504    #[must_use]
505    pub fn with_max_tokens(mut self, tokens: u32) -> Self {
506        self.max_tokens = Some(tokens);
507        self
508    }
509
510    /// Set temperature
511    #[must_use]
512    pub fn with_temperature(mut self, temp: f32) -> Self {
513        self.temperature = Some(temp);
514        self
515    }
516}
517
518/// Lambda inference response
519#[derive(Debug, Clone, Serialize, Deserialize)]
520pub struct InferenceResponse {
521    /// Generated output
522    pub output: String,
523    /// Number of tokens generated
524    pub tokens_generated: u32,
525    /// Inference latency in ms
526    pub latency_ms: u64,
527    /// Whether this was a cold start
528    pub cold_start: bool,
529}
530
/// Lambda inference client.
///
/// Carries the invocation target (function ARN or name), region, and a
/// client-side timeout; this type holds configuration only — the actual AWS
/// invocation happens elsewhere.
#[derive(Debug, Clone)]
pub struct LambdaClient {
    /// Function ARN or name to invoke
    function_arn: String,
    /// AWS region hosting the function
    region: String,
    /// Client-side invocation timeout (defaults to 60s)
    timeout: Duration,
}

impl LambdaClient {
    /// Create a client for the given function and region with the default
    /// 60-second invocation timeout.
    #[must_use]
    pub fn new(function_arn: impl Into<String>, region: impl Into<String>) -> Self {
        Self {
            function_arn: function_arn.into(),
            region: region.into(),
            timeout: Duration::from_secs(60),
        }
    }

    /// Override the invocation timeout.
    #[must_use]
    pub fn with_timeout(mut self, timeout: Duration) -> Self {
        self.timeout = timeout;
        self
    }

    /// Function ARN or name this client targets.
    #[must_use]
    pub fn function_arn(&self) -> &str {
        self.function_arn.as_str()
    }

    /// AWS region.
    #[must_use]
    pub fn region(&self) -> &str {
        self.region.as_str()
    }

    /// Current invocation timeout.
    #[must_use]
    pub fn timeout(&self) -> Duration {
        self.timeout
    }
}
578
// ============================================================================
// SERVE-LAM-004: Error Types
// ============================================================================

/// Error raised by [`LambdaConfig::validate`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ConfigError {
    /// A required field was left empty (the field name is attached)
    MissingField(&'static str),
    /// Memory size outside the 128-10240 MB range Lambda accepts
    InvalidMemory(u32),
    /// Timeout outside the 1-900 second range Lambda accepts
    InvalidTimeout(u32),
}

impl std::fmt::Display for ConfigError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match *self {
            Self::MissingField(field) => write!(f, "Missing required field: {field}"),
            Self::InvalidMemory(mb) => write!(f, "Invalid memory size: {mb}MB (must be 128-10240)"),
            Self::InvalidTimeout(secs) => write!(f, "Invalid timeout: {secs}s (must be 1-900)"),
        }
    }
}

impl std::error::Error for ConfigError {}
605
606/// Deployment error
607#[derive(Debug)]
608pub enum DeploymentError {
609    /// Configuration error
610    Config(ConfigError),
611    /// AWS API error
612    AwsError(String),
613    /// Model not found
614    ModelNotFound(String),
615    /// Package too large
616    PackageTooLarge { size_mb: u64, max_mb: u64 },
617}
618
619impl std::fmt::Display for DeploymentError {
620    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
621        match self {
622            Self::Config(e) => write!(f, "Configuration error: {e}"),
623            Self::AwsError(e) => write!(f, "AWS error: {e}"),
624            Self::ModelNotFound(uri) => write!(f, "Model not found: {uri}"),
625            Self::PackageTooLarge { size_mb, max_mb } => {
626                write!(f, "Package too large: {size_mb}MB (max {max_mb}MB)")
627            }
628        }
629    }
630}
631
632impl std::error::Error for DeploymentError {}
633
634// ============================================================================
635// Tests
636// ============================================================================
637
638#[cfg(test)]
639#[allow(non_snake_case)]
640#[path = "lambda_tests.rs"]
641mod tests;