1use crate::{Result, ServerlessError};
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6
/// Serverless platforms this module knows how to configure.
///
/// `Custom` is the variant used by `ProviderConfig::default()`; it is never
/// produced by [`ProviderType::detect`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ProviderType {
    /// AWS Lambda.
    AwsLambda,
    /// Cloudflare Workers.
    CloudflareWorkers,
    /// Google Cloud Functions.
    GoogleCloudFunctions,
    /// Azure Functions.
    AzureFunctions,
    /// A platform not covered by the other variants.
    Custom,
}
21
22impl ProviderType {
23 pub fn name(&self) -> &'static str {
25 match self {
26 ProviderType::AwsLambda => "AWS Lambda",
27 ProviderType::CloudflareWorkers => "Cloudflare Workers",
28 ProviderType::GoogleCloudFunctions => "Google Cloud Functions",
29 ProviderType::AzureFunctions => "Azure Functions",
30 ProviderType::Custom => "Custom",
31 }
32 }
33
34 pub fn detect() -> Option<Self> {
36 if std::env::var("AWS_LAMBDA_FUNCTION_NAME").is_ok() {
37 Some(ProviderType::AwsLambda)
38 } else if std::env::var("CF_WORKER").is_ok() {
39 Some(ProviderType::CloudflareWorkers)
40 } else if std::env::var("FUNCTION_NAME").is_ok()
41 && std::env::var("GOOGLE_CLOUD_PROJECT").is_ok()
42 {
43 Some(ProviderType::GoogleCloudFunctions)
44 } else if std::env::var("FUNCTIONS_WORKER_RUNTIME").is_ok() {
45 Some(ProviderType::AzureFunctions)
46 } else {
47 None
48 }
49 }
50}
51
/// Resource limits and settings for a single provider deployment.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProviderConfig {
    /// Platform this configuration targets.
    pub provider_type: ProviderType,
    /// Memory available to the function, in megabytes.
    pub memory_mb: u64,
    /// Execution time budget, in seconds.
    pub timeout_seconds: u64,
    /// Whether a GPU may be used (surfaced by `Provider::has_gpu`).
    pub gpu_available: bool,
    /// Largest accepted payload; enforced by `Provider::validate_payload_size`.
    /// Presumably bytes, judging by the `6 * 1024 * 1024` presets — confirm.
    pub max_payload_size: u64,
    /// Environment variables attached to this configuration (not read in this module).
    pub env_vars: HashMap<String, String>,
    /// Free-form provider-specific settings (not read in this module).
    pub custom: HashMap<String, String>,
}
70
71impl Default for ProviderConfig {
72 fn default() -> Self {
73 Self {
74 provider_type: ProviderType::Custom,
75 memory_mb: 1024,
76 timeout_seconds: 30,
77 gpu_available: false,
78 max_payload_size: 6 * 1024 * 1024, env_vars: HashMap::new(),
80 custom: HashMap::new(),
81 }
82 }
83}
84
85impl ProviderConfig {
86 pub fn aws_lambda(memory_mb: u64) -> Self {
88 Self {
89 provider_type: ProviderType::AwsLambda,
90 memory_mb,
91 timeout_seconds: 900, gpu_available: false,
93 max_payload_size: 6 * 1024 * 1024,
94 ..Default::default()
95 }
96 }
97
98 pub fn cloudflare_workers() -> Self {
100 Self {
101 provider_type: ProviderType::CloudflareWorkers,
102 memory_mb: 128,
103 timeout_seconds: 30,
104 gpu_available: false,
105 max_payload_size: 100 * 1024 * 1024, ..Default::default()
107 }
108 }
109}
110
/// Runtime handle pairing a [`ProviderConfig`] with its initialization state.
#[derive(Debug)]
pub struct Provider {
    /// Configuration the provider was constructed with.
    config: ProviderConfig,
    /// Becomes `true` once `Provider::initialize` has completed without error.
    initialized: bool,
}
119
120impl Provider {
121 pub fn new(config: ProviderConfig) -> Self {
123 Self {
124 config,
125 initialized: false,
126 }
127 }
128
129 pub fn from_env() -> Result<Self> {
131 let provider_type = ProviderType::detect()
132 .ok_or_else(|| ServerlessError::ProviderError("Unknown provider".into()))?;
133
134 let config = match provider_type {
135 ProviderType::AwsLambda => {
136 let memory = std::env::var("AWS_LAMBDA_FUNCTION_MEMORY_SIZE")
137 .ok()
138 .and_then(|s| s.parse().ok())
139 .unwrap_or(1024);
140 ProviderConfig::aws_lambda(memory)
141 }
142 ProviderType::CloudflareWorkers => ProviderConfig::cloudflare_workers(),
143 _ => ProviderConfig::default(),
144 };
145
146 Ok(Self::new(config))
147 }
148
149 pub async fn initialize(&mut self) -> Result<()> {
151 match self.config.provider_type {
153 ProviderType::AwsLambda => {
154 self.init_lambda().await?;
155 }
156 ProviderType::CloudflareWorkers => {
157 self.init_cloudflare().await?;
158 }
159 _ => {}
160 }
161
162 self.initialized = true;
163 Ok(())
164 }
165
166 async fn init_lambda(&self) -> Result<()> {
167 Ok(())
172 }
173
174 async fn init_cloudflare(&self) -> Result<()> {
175 Ok(())
179 }
180
181 pub fn config(&self) -> &ProviderConfig {
183 &self.config
184 }
185
186 pub fn is_initialized(&self) -> bool {
188 self.initialized
189 }
190
191 pub fn remaining_time_ms(&self) -> Option<u64> {
193 match self.config.provider_type {
194 ProviderType::AwsLambda => {
195 Some(self.config.timeout_seconds * 1000)
197 }
198 _ => Some(self.config.timeout_seconds * 1000),
199 }
200 }
201
202 pub fn validate_payload_size(&self, size: usize) -> Result<()> {
204 if size as u64 > self.config.max_payload_size {
205 return Err(ServerlessError::ProviderError(format!(
206 "Payload size {} exceeds limit {}",
207 size, self.config.max_payload_size
208 )));
209 }
210 Ok(())
211 }
212
213 pub fn available_memory_mb(&self) -> u64 {
215 self.config.memory_mb
216 }
217
218 pub fn has_gpu(&self) -> bool {
220 self.config.gpu_available
221 }
222}
223
/// Static feature flags and resource ceilings of a provider platform.
///
/// Built by `ProviderCapabilities::for_provider`. For `Custom`, both limits
/// are `u64::MAX`, which acts as an "unbounded" sentinel.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProviderCapabilities {
    /// WebSocket support.
    pub websocket: bool,
    /// Streaming response support.
    pub streaming: bool,
    /// GPU availability.
    pub gpu: bool,
    /// Built-in storage support.
    pub storage: bool,
    /// Scheduled (cron-style) invocation support.
    pub scheduled: bool,
    /// Largest memory setting the platform allows, in megabytes.
    pub max_memory_mb: u64,
    /// Longest execution timeout the platform allows, in seconds.
    pub max_timeout_seconds: u64,
}
242
243impl ProviderCapabilities {
244 pub fn for_provider(provider: ProviderType) -> Self {
246 match provider {
247 ProviderType::AwsLambda => Self {
248 websocket: false,
249 streaming: true,
250 gpu: false,
251 storage: false, scheduled: true,
253 max_memory_mb: 10240,
254 max_timeout_seconds: 900,
255 },
256 ProviderType::CloudflareWorkers => Self {
257 websocket: true,
258 streaming: true,
259 gpu: false,
260 storage: true, scheduled: true,
262 max_memory_mb: 128,
263 max_timeout_seconds: 30,
264 },
265 ProviderType::GoogleCloudFunctions => Self {
266 websocket: false,
267 streaming: false,
268 gpu: false,
269 storage: false,
270 scheduled: true,
271 max_memory_mb: 32768,
272 max_timeout_seconds: 3600,
273 },
274 ProviderType::AzureFunctions => Self {
275 websocket: false,
276 streaming: true,
277 gpu: false,
278 storage: false,
279 scheduled: true,
280 max_memory_mb: 14336,
281 max_timeout_seconds: 600,
282 },
283 ProviderType::Custom => Self {
284 websocket: true,
285 streaming: true,
286 gpu: true,
287 storage: true,
288 scheduled: true,
289 max_memory_mb: u64::MAX,
290 max_timeout_seconds: u64::MAX,
291 },
292 }
293 }
294}
295
/// Per-invocation metadata about the current request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RequestContext {
    /// Request identifier. `RequestContext::from_lambda_env` fills it from
    /// `_X_AMZN_TRACE_ID`.
    pub request_id: String,
    /// Name of the running function.
    pub function_name: String,
    /// Invocation counter (left at `0` by `from_lambda_env`).
    pub invocation_count: u64,
    /// Memory limit, in megabytes.
    pub memory_limit_mb: u64,
    /// Remaining time budget in milliseconds (left at `0` by `from_lambda_env`).
    pub timeout_remaining_ms: u64,
    /// Whether this invocation is a cold start (left `false` by `from_lambda_env`).
    pub is_cold_start: bool,
}
312
313impl RequestContext {
314 pub fn from_lambda_env() -> Option<Self> {
316 Some(Self {
317 request_id: std::env::var("_X_AMZN_TRACE_ID").ok()?,
318 function_name: std::env::var("AWS_LAMBDA_FUNCTION_NAME").ok()?,
319 invocation_count: 0, memory_limit_mb: std::env::var("AWS_LAMBDA_FUNCTION_MEMORY_SIZE")
321 .ok()?
322 .parse()
323 .ok()?,
324 timeout_remaining_ms: 0, is_cold_start: false, })
327 }
328}
329
#[cfg(test)]
mod tests {
    //! Unit tests for provider presets, payload validation and capabilities.
    //! Environment-dependent paths (`detect`, `from_env`) are not exercised here
    //! because tests run in parallel and share the process environment.
    use super::*;

    #[test]
    fn test_provider_type_name() {
        assert_eq!(ProviderType::AwsLambda.name(), "AWS Lambda");
        assert_eq!(ProviderType::CloudflareWorkers.name(), "Cloudflare Workers");
        assert_eq!(ProviderType::GoogleCloudFunctions.name(), "Google Cloud Functions");
        assert_eq!(ProviderType::AzureFunctions.name(), "Azure Functions");
        assert_eq!(ProviderType::Custom.name(), "Custom");
    }

    #[test]
    fn test_config_default() {
        let config = ProviderConfig::default();
        assert_eq!(config.provider_type, ProviderType::Custom);
        assert_eq!(config.memory_mb, 1024);
        assert_eq!(config.timeout_seconds, 30);
        assert!(!config.gpu_available);
        assert_eq!(config.max_payload_size, 6 * 1024 * 1024);
        assert!(config.env_vars.is_empty());
        assert!(config.custom.is_empty());
    }

    #[test]
    fn test_aws_lambda_config() {
        let config = ProviderConfig::aws_lambda(2048);
        assert_eq!(config.provider_type, ProviderType::AwsLambda);
        assert_eq!(config.memory_mb, 2048);
        assert_eq!(config.timeout_seconds, 900);
        assert_eq!(config.max_payload_size, 6 * 1024 * 1024);
        assert!(!config.gpu_available);
    }

    #[test]
    fn test_cloudflare_config() {
        let config = ProviderConfig::cloudflare_workers();
        assert_eq!(config.provider_type, ProviderType::CloudflareWorkers);
        assert_eq!(config.memory_mb, 128);
        assert_eq!(config.timeout_seconds, 30);
        assert_eq!(config.max_payload_size, 100 * 1024 * 1024);
        assert!(!config.gpu_available);
    }

    #[test]
    fn test_provider_creation() {
        let config = ProviderConfig::default();
        let provider = Provider::new(config);

        assert!(!provider.is_initialized());
        assert_eq!(provider.available_memory_mb(), 1024);
        assert!(!provider.has_gpu());
        assert_eq!(provider.config().provider_type, ProviderType::Custom);
    }

    #[test]
    fn test_payload_validation() {
        let config = ProviderConfig {
            max_payload_size: 1024,
            ..Default::default()
        };
        let provider = Provider::new(config);

        assert!(provider.validate_payload_size(0).is_ok());
        assert!(provider.validate_payload_size(512).is_ok());
        // The limit itself is inclusive; one byte more is rejected.
        assert!(provider.validate_payload_size(1024).is_ok());
        assert!(provider.validate_payload_size(1025).is_err());
        assert!(provider.validate_payload_size(2048).is_err());
    }

    #[test]
    fn test_remaining_time_reports_configured_timeout() {
        let lambda = Provider::new(ProviderConfig::aws_lambda(128));
        assert_eq!(lambda.remaining_time_ms(), Some(900_000));

        let generic = Provider::new(ProviderConfig::default());
        assert_eq!(generic.remaining_time_ms(), Some(30_000));
    }

    #[test]
    fn test_capabilities() {
        let lambda_caps = ProviderCapabilities::for_provider(ProviderType::AwsLambda);
        assert!(!lambda_caps.websocket);
        assert!(lambda_caps.streaming);
        assert_eq!(lambda_caps.max_memory_mb, 10240);
        assert_eq!(lambda_caps.max_timeout_seconds, 900);

        let cf_caps = ProviderCapabilities::for_provider(ProviderType::CloudflareWorkers);
        assert!(cf_caps.websocket);
        assert!(cf_caps.storage);
        assert_eq!(cf_caps.max_memory_mb, 128);

        let gcf_caps = ProviderCapabilities::for_provider(ProviderType::GoogleCloudFunctions);
        assert!(!gcf_caps.streaming);
        assert_eq!(gcf_caps.max_timeout_seconds, 3600);

        let azure_caps = ProviderCapabilities::for_provider(ProviderType::AzureFunctions);
        assert!(azure_caps.streaming);
        assert_eq!(azure_caps.max_memory_mb, 14336);

        let custom_caps = ProviderCapabilities::for_provider(ProviderType::Custom);
        assert!(custom_caps.gpu);
        assert_eq!(custom_caps.max_memory_mb, u64::MAX);
        assert_eq!(custom_caps.max_timeout_seconds, u64::MAX);
    }
}