1use serde::{Deserialize, Serialize};
19use std::collections::HashMap;
20use std::time::Duration;
21
/// Configuration for deploying an ML inference function to AWS Lambda.
///
/// Build with [`LambdaConfig::new`] plus the `with_*` builder methods, then
/// check invariants with [`LambdaConfig::validate`] before deploying.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LambdaConfig {
    /// Lambda function name (must be non-empty; see `validate`).
    pub function_name: String,
    /// Target AWS region, e.g. `us-east-1`.
    pub region: String,
    /// Memory allocation in MB (Lambda allows 128-10240; see `validate`).
    pub memory_mb: u32,
    /// Invocation timeout in seconds (Lambda allows 1-900; see `validate`).
    pub timeout_secs: u32,
    /// Runtime the function executes under.
    pub runtime: LambdaRuntime,
    /// URI of the model artifact to load (must be non-empty; see `validate`).
    pub model_uri: String,
    /// Extra environment variables injected into the function.
    pub environment: HashMap<String, String>,
    /// Number of provisioned-concurrency instances (0 disables).
    pub provisioned_concurrency: u32,
    /// Optional VPC attachment (subnets + security groups).
    pub vpc_config: Option<VpcConfig>,
    /// Ephemeral `/tmp` storage in MB (builder clamps to 512-10240).
    pub ephemeral_storage_mb: u32,
    /// CPU architecture the function runs on.
    pub architecture: LambdaArchitecture,
}
52
53impl Default for LambdaConfig {
54 fn default() -> Self {
55 Self {
56 function_name: String::new(),
57 region: "us-east-1".to_string(),
58 memory_mb: 3008, timeout_secs: 60,
60 runtime: LambdaRuntime::Provided,
61 model_uri: String::new(),
62 environment: HashMap::new(),
63 provisioned_concurrency: 0,
64 vpc_config: None,
65 ephemeral_storage_mb: 512,
66 architecture: LambdaArchitecture::Arm64, }
68 }
69}
70
71impl LambdaConfig {
72 #[must_use]
74 pub fn new(function_name: impl Into<String>) -> Self {
75 Self { function_name: function_name.into(), ..Default::default() }
76 }
77
78 #[must_use]
80 pub fn with_model(mut self, model_uri: impl Into<String>) -> Self {
81 self.model_uri = model_uri.into();
82 self
83 }
84
85 #[must_use]
87 pub fn with_memory(mut self, mb: u32) -> Self {
88 self.memory_mb = mb.clamp(128, 10240);
89 self
90 }
91
92 #[must_use]
94 pub fn with_timeout(mut self, secs: u32) -> Self {
95 self.timeout_secs = secs.clamp(1, 900);
96 self
97 }
98
99 #[must_use]
101 pub fn with_region(mut self, region: impl Into<String>) -> Self {
102 self.region = region.into();
103 self
104 }
105
106 #[must_use]
108 pub fn with_runtime(mut self, runtime: LambdaRuntime) -> Self {
109 self.runtime = runtime;
110 self
111 }
112
113 #[must_use]
115 pub fn with_env(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
116 self.environment.insert(key.into(), value.into());
117 self
118 }
119
120 #[must_use]
122 pub fn with_provisioned_concurrency(mut self, count: u32) -> Self {
123 self.provisioned_concurrency = count;
124 self
125 }
126
127 #[must_use]
129 pub fn with_vpc(mut self, vpc: VpcConfig) -> Self {
130 self.vpc_config = Some(vpc);
131 self
132 }
133
134 #[must_use]
136 pub fn with_storage(mut self, mb: u32) -> Self {
137 self.ephemeral_storage_mb = mb.clamp(512, 10240);
138 self
139 }
140
141 #[must_use]
143 pub fn with_architecture(mut self, arch: LambdaArchitecture) -> Self {
144 self.architecture = arch;
145 self
146 }
147
148 pub fn validate(&self) -> Result<(), ConfigError> {
150 if self.function_name.is_empty() {
151 return Err(ConfigError::MissingField("function_name"));
152 }
153 if self.model_uri.is_empty() {
154 return Err(ConfigError::MissingField("model_uri"));
155 }
156 if self.memory_mb < 128 || self.memory_mb > 10240 {
157 return Err(ConfigError::InvalidMemory(self.memory_mb));
158 }
159 if self.timeout_secs == 0 || self.timeout_secs > 900 {
160 return Err(ConfigError::InvalidTimeout(self.timeout_secs));
161 }
162 Ok(())
163 }
164
165 #[must_use]
167 pub fn estimate_cost(&self, invocations_per_month: u64, avg_duration_ms: u64) -> f64 {
168 let gb_seconds = (self.memory_mb as f64 / 1024.0)
174 * (avg_duration_ms as f64 / 1000.0)
175 * invocations_per_month as f64;
176
177 let compute_cost = gb_seconds * 0.0000133334;
178 let request_cost = (invocations_per_month as f64 / 1_000_000.0) * 0.20;
179
180 let provisioned_cost = if self.provisioned_concurrency > 0 {
182 let gb_provisioned =
183 (self.memory_mb as f64 / 1024.0) * self.provisioned_concurrency as f64;
184 gb_provisioned * 0.0000041667 * 3600.0 * 24.0 * 30.0 } else {
186 0.0
187 };
188
189 compute_cost + request_cost + provisioned_cost
190 }
191}
192
/// Lambda runtime families supported by this deployer.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
pub enum LambdaRuntime {
    /// Custom runtime on Amazon Linux 2023 (`provided.al2023`); the default.
    #[default]
    Provided,
    /// Managed Python 3.12 runtime.
    Python312,
    /// Container-image-packaged function.
    Container,
}
204
205impl LambdaRuntime {
206 #[must_use]
208 pub const fn identifier(&self) -> &'static str {
209 match self {
210 Self::Provided => "provided.al2023",
211 Self::Python312 => "python3.12",
212 Self::Container => "container",
213 }
214 }
215}
216
/// CPU architectures supported by Lambda.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
pub enum LambdaArchitecture {
    /// ARM64 (AWS Graviton); the default.
    #[default]
    Arm64,
    /// x86_64.
    X86_64,
}
226
227impl LambdaArchitecture {
228 #[must_use]
230 pub const fn identifier(&self) -> &'static str {
231 match self {
232 Self::Arm64 => "arm64",
233 Self::X86_64 => "x86_64",
234 }
235 }
236}
237
/// VPC attachment settings for a Lambda function.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VpcConfig {
    /// Subnet ids the function's network interfaces are placed in.
    pub subnet_ids: Vec<String>,
    /// Security-group ids applied to those interfaces.
    pub security_group_ids: Vec<String>,
}
246
247impl VpcConfig {
248 #[must_use]
250 pub fn new(subnet_ids: Vec<String>, security_group_ids: Vec<String>) -> Self {
251 Self { subnet_ids, security_group_ids }
252 }
253}
254
/// Drives deployment of a [`LambdaConfig`] and tracks its progress.
#[derive(Debug, Clone)]
pub struct LambdaDeployer {
    /// Configuration being deployed.
    config: LambdaConfig,
    /// Current lifecycle state.
    status: DeploymentStatus,
    /// ARN assigned once the function is active.
    function_arn: Option<String>,
}
269
/// Lifecycle states of a Lambda deployment.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
pub enum DeploymentStatus {
    /// No deployment attempted yet (initial state).
    #[default]
    NotDeployed,
    /// Building the deployment package.
    Packaging,
    /// Uploading the package to AWS.
    Uploading,
    /// Creating or updating the function.
    Deploying,
    /// Function is live and invocable.
    Active,
    /// Deployment failed.
    Failed,
}
287
288impl LambdaDeployer {
289 #[must_use]
291 pub fn new(config: LambdaConfig) -> Self {
292 Self { config, status: DeploymentStatus::NotDeployed, function_arn: None }
293 }
294
295 #[must_use]
297 pub fn status(&self) -> DeploymentStatus {
298 self.status
299 }
300
301 #[must_use]
303 pub fn function_arn(&self) -> Option<&str> {
304 self.function_arn.as_deref()
305 }
306
307 #[must_use]
309 pub fn config(&self) -> &LambdaConfig {
310 &self.config
311 }
312
313 pub fn validate(&self) -> Result<(), DeploymentError> {
315 self.config.validate().map_err(DeploymentError::Config)?;
316 Ok(())
317 }
318
319 #[must_use]
321 pub fn estimate(&self) -> DeploymentEstimate {
322 let model_size_mb = 1024; let package_size_mb = model_size_mb + 50; DeploymentEstimate {
326 package_size_mb,
327 estimated_cold_start_ms: estimate_cold_start(&self.config),
328 monthly_cost_1k_req: self.config.estimate_cost(1000, 500),
329 monthly_cost_100k_req: self.config.estimate_cost(100_000, 500),
330 monthly_cost_1m_req: self.config.estimate_cost(1_000_000, 500),
331 }
332 }
333
334 #[must_use]
336 pub fn generate_iac(&self) -> String {
337 let vpc_config = if let Some(ref vpc) = self.config.vpc_config {
338 format!(
339 r"
340 VpcConfig:
341 SubnetIds:
342 {}
343 SecurityGroupIds:
344 {}",
345 vpc.subnet_ids
346 .iter()
347 .map(|s| format!("- {s}"))
348 .collect::<Vec<_>>()
349 .join("\n "),
350 vpc.security_group_ids
351 .iter()
352 .map(|s| format!("- {s}"))
353 .collect::<Vec<_>>()
354 .join("\n ")
355 )
356 } else {
357 String::new()
358 };
359
360 let _provisioned = if self.config.provisioned_concurrency > 0 {
361 format!(
362 r"
363 {}Concurrency:
364 Type: AWS::Lambda::Version
365 Properties:
366 FunctionName: !Ref {}Function
367 ProvisionedConcurrencyConfig:
368 ProvisionedConcurrentExecutions: {}",
369 self.config.function_name,
370 self.config.function_name,
371 self.config.provisioned_concurrency
372 )
373 } else {
374 String::new()
375 };
376
377 format!(
378 r"AWSTemplateFormatVersion: '2010-09-09'
379Transform: AWS::Serverless-2016-10-31
380Description: ML Inference Lambda - {}
381
382Resources:
383 {}Function:
384 Type: AWS::Serverless::Function
385 Properties:
386 FunctionName: {}
387 Runtime: {}
388 Handler: bootstrap
389 CodeUri: ./deployment-package.zip
390 MemorySize: {}
391 Timeout: {}
392 Architectures:
393 - {}
394 Environment:
395 Variables:
396 MODEL_URI: {}
397 RUST_LOG: info{}{}
398 EphemeralStorage:
399 Size: {}
400
401Outputs:
402 FunctionArn:
403 Description: Lambda Function ARN
404 Value: !GetAtt {}Function.Arn
405 FunctionUrl:
406 Description: Lambda Function URL
407 Value: !GetAtt {}FunctionUrl.FunctionUrl",
408 self.config.function_name,
409 self.config.function_name,
410 self.config.function_name,
411 self.config.runtime.identifier(),
412 self.config.memory_mb,
413 self.config.timeout_secs,
414 self.config.architecture.identifier(),
415 self.config.model_uri,
416 self.config
417 .environment
418 .iter()
419 .map(|(k, v)| format!("\n {k}: {v}"))
420 .collect::<String>(),
421 vpc_config,
422 self.config.ephemeral_storage_mb,
423 self.config.function_name,
424 self.config.function_name,
425 )
426 }
427
428 pub fn set_status(&mut self, status: DeploymentStatus) {
430 self.status = status;
431 }
432
433 pub fn set_function_arn(&mut self, arn: String) {
435 self.function_arn = Some(arn);
436 self.status = DeploymentStatus::Active;
437 }
438}
439
/// Pre-deployment estimates for package size, cold start, and monthly cost.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeploymentEstimate {
    /// Estimated deployment-package size in MB.
    pub package_size_mb: u64,
    /// Estimated cold-start latency in milliseconds.
    pub estimated_cold_start_ms: u64,
    /// Estimated monthly cost (USD) at 1,000 requests per month.
    pub monthly_cost_1k_req: f64,
    /// Estimated monthly cost (USD) at 100,000 requests per month.
    pub monthly_cost_100k_req: f64,
    /// Estimated monthly cost (USD) at 1,000,000 requests per month.
    pub monthly_cost_1m_req: f64,
}
454
455fn estimate_cold_start(config: &LambdaConfig) -> u64 {
457 let base_ms: u64 = match config.runtime {
459 LambdaRuntime::Provided => 100,
460 LambdaRuntime::Python312 => 200,
461 LambdaRuntime::Container => 500,
462 };
463
464 let memory_factor =
466 if config.memory_mb >= 3008 { 1.0 } else { 1.5 - (config.memory_mb as f64 / 6016.0) };
467
468 let model_load_ms: u64 = 2000; ((base_ms as f64 * memory_factor) as u64) + model_load_ms
472}
473
/// Request payload sent to the inference Lambda.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InferenceRequest {
    /// Input text/prompt for the model.
    pub input: String,
    /// Optional cap on the number of generated tokens.
    pub max_tokens: Option<u32>,
    /// Optional sampling temperature.
    pub temperature: Option<f32>,
    /// Additional model-specific parameters (free-form JSON values).
    pub parameters: HashMap<String, serde_json::Value>,
}
490
491impl InferenceRequest {
492 #[must_use]
494 pub fn new(input: impl Into<String>) -> Self {
495 Self {
496 input: input.into(),
497 max_tokens: None,
498 temperature: None,
499 parameters: HashMap::new(),
500 }
501 }
502
503 #[must_use]
505 pub fn with_max_tokens(mut self, tokens: u32) -> Self {
506 self.max_tokens = Some(tokens);
507 self
508 }
509
510 #[must_use]
512 pub fn with_temperature(mut self, temp: f32) -> Self {
513 self.temperature = Some(temp);
514 self
515 }
516}
517
/// Response payload returned by the inference Lambda.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InferenceResponse {
    /// Generated output text.
    pub output: String,
    /// Number of tokens generated.
    pub tokens_generated: u32,
    /// End-to-end invocation latency in milliseconds.
    pub latency_ms: u64,
    /// Whether this invocation incurred a cold start.
    pub cold_start: bool,
}
530
/// Lightweight handle for invoking a deployed Lambda function.
#[derive(Debug, Clone)]
pub struct LambdaClient {
    /// ARN of the function to invoke.
    function_arn: String,
    /// AWS region hosting the function.
    region: String,
    /// Per-invocation timeout.
    timeout: Duration,
}
541
542impl LambdaClient {
543 #[must_use]
545 pub fn new(function_arn: impl Into<String>, region: impl Into<String>) -> Self {
546 Self {
547 function_arn: function_arn.into(),
548 region: region.into(),
549 timeout: Duration::from_secs(60),
550 }
551 }
552
553 #[must_use]
555 pub fn with_timeout(mut self, timeout: Duration) -> Self {
556 self.timeout = timeout;
557 self
558 }
559
560 #[must_use]
562 pub fn function_arn(&self) -> &str {
563 &self.function_arn
564 }
565
566 #[must_use]
568 pub fn region(&self) -> &str {
569 &self.region
570 }
571
572 #[must_use]
574 pub fn timeout(&self) -> Duration {
575 self.timeout
576 }
577}
578
/// Validation errors produced by [`LambdaConfig::validate`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ConfigError {
    /// The named required field was left empty.
    MissingField(&'static str),
    /// Memory outside Lambda's 128-10240 MB range.
    InvalidMemory(u32),
    /// Timeout outside Lambda's 1-900 s range.
    InvalidTimeout(u32),
}
593
594impl std::fmt::Display for ConfigError {
595 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
596 match self {
597 Self::MissingField(field) => write!(f, "Missing required field: {field}"),
598 Self::InvalidMemory(mb) => write!(f, "Invalid memory size: {mb}MB (must be 128-10240)"),
599 Self::InvalidTimeout(secs) => write!(f, "Invalid timeout: {secs}s (must be 1-900)"),
600 }
601 }
602}
603
604impl std::error::Error for ConfigError {}
605
/// Errors that can occur while deploying a function.
#[derive(Debug)]
pub enum DeploymentError {
    /// The configuration failed validation.
    Config(ConfigError),
    /// An AWS API call failed (message from the service).
    AwsError(String),
    /// The model artifact at the given URI could not be found.
    ModelNotFound(String),
    /// The deployment package exceeds the size limit.
    PackageTooLarge { size_mb: u64, max_mb: u64 },
}
618
619impl std::fmt::Display for DeploymentError {
620 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
621 match self {
622 Self::Config(e) => write!(f, "Configuration error: {e}"),
623 Self::AwsError(e) => write!(f, "AWS error: {e}"),
624 Self::ModelNotFound(uri) => write!(f, "Model not found: {uri}"),
625 Self::PackageTooLarge { size_mb, max_mb } => {
626 write!(f, "Package too large: {size_mb}MB (max {max_mb}MB)")
627 }
628 }
629 }
630}
631
632impl std::error::Error for DeploymentError {}
633
// Unit tests live in a sibling file; `#[path]` points the module there.
#[cfg(test)]
#[allow(non_snake_case)]
#[path = "lambda_tests.rs"]
mod tests;