Skip to main content

haagenti_serverless/
provider.rs

1//! Multi-provider support for serverless deployment
2
3use crate::{Result, ServerlessError};
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6
7/// Provider type
8#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
9pub enum ProviderType {
10    /// AWS Lambda
11    AwsLambda,
12    /// Cloudflare Workers
13    CloudflareWorkers,
14    /// Google Cloud Functions
15    GoogleCloudFunctions,
16    /// Azure Functions
17    AzureFunctions,
18    /// Custom/Self-hosted
19    Custom,
20}
21
22impl ProviderType {
23    /// Get provider name
24    pub fn name(&self) -> &'static str {
25        match self {
26            ProviderType::AwsLambda => "AWS Lambda",
27            ProviderType::CloudflareWorkers => "Cloudflare Workers",
28            ProviderType::GoogleCloudFunctions => "Google Cloud Functions",
29            ProviderType::AzureFunctions => "Azure Functions",
30            ProviderType::Custom => "Custom",
31        }
32    }
33
34    /// Detect from environment
35    pub fn detect() -> Option<Self> {
36        if std::env::var("AWS_LAMBDA_FUNCTION_NAME").is_ok() {
37            Some(ProviderType::AwsLambda)
38        } else if std::env::var("CF_WORKER").is_ok() {
39            Some(ProviderType::CloudflareWorkers)
40        } else if std::env::var("FUNCTION_NAME").is_ok()
41            && std::env::var("GOOGLE_CLOUD_PROJECT").is_ok()
42        {
43            Some(ProviderType::GoogleCloudFunctions)
44        } else if std::env::var("FUNCTIONS_WORKER_RUNTIME").is_ok() {
45            Some(ProviderType::AzureFunctions)
46        } else {
47            None
48        }
49    }
50}
51
52/// Provider configuration
53#[derive(Debug, Clone, Serialize, Deserialize)]
54pub struct ProviderConfig {
55    /// Provider type
56    pub provider_type: ProviderType,
57    /// Memory limit in MB
58    pub memory_mb: u64,
59    /// Timeout in seconds
60    pub timeout_seconds: u64,
61    /// GPU available
62    pub gpu_available: bool,
63    /// Maximum payload size in bytes
64    pub max_payload_size: u64,
65    /// Environment variables
66    pub env_vars: HashMap<String, String>,
67    /// Custom settings
68    pub custom: HashMap<String, String>,
69}
70
71impl Default for ProviderConfig {
72    fn default() -> Self {
73        Self {
74            provider_type: ProviderType::Custom,
75            memory_mb: 1024,
76            timeout_seconds: 30,
77            gpu_available: false,
78            max_payload_size: 6 * 1024 * 1024, // 6MB
79            env_vars: HashMap::new(),
80            custom: HashMap::new(),
81        }
82    }
83}
84
85impl ProviderConfig {
86    /// Create config for AWS Lambda
87    pub fn aws_lambda(memory_mb: u64) -> Self {
88        Self {
89            provider_type: ProviderType::AwsLambda,
90            memory_mb,
91            timeout_seconds: 900, // 15 minutes max
92            gpu_available: false,
93            max_payload_size: 6 * 1024 * 1024,
94            ..Default::default()
95        }
96    }
97
98    /// Create config for Cloudflare Workers
99    pub fn cloudflare_workers() -> Self {
100        Self {
101            provider_type: ProviderType::CloudflareWorkers,
102            memory_mb: 128,
103            timeout_seconds: 30,
104            gpu_available: false,
105            max_payload_size: 100 * 1024 * 1024, // 100MB
106            ..Default::default()
107        }
108    }
109}
110
111/// Provider abstraction
112#[derive(Debug)]
113pub struct Provider {
114    /// Configuration
115    config: ProviderConfig,
116    /// Initialized
117    initialized: bool,
118}
119
120impl Provider {
121    /// Create new provider
122    pub fn new(config: ProviderConfig) -> Self {
123        Self {
124            config,
125            initialized: false,
126        }
127    }
128
129    /// Detect and create from environment
130    pub fn from_env() -> Result<Self> {
131        let provider_type = ProviderType::detect()
132            .ok_or_else(|| ServerlessError::ProviderError("Unknown provider".into()))?;
133
134        let config = match provider_type {
135            ProviderType::AwsLambda => {
136                let memory = std::env::var("AWS_LAMBDA_FUNCTION_MEMORY_SIZE")
137                    .ok()
138                    .and_then(|s| s.parse().ok())
139                    .unwrap_or(1024);
140                ProviderConfig::aws_lambda(memory)
141            }
142            ProviderType::CloudflareWorkers => ProviderConfig::cloudflare_workers(),
143            _ => ProviderConfig::default(),
144        };
145
146        Ok(Self::new(config))
147    }
148
149    /// Initialize provider
150    pub async fn initialize(&mut self) -> Result<()> {
151        // Provider-specific initialization
152        match self.config.provider_type {
153            ProviderType::AwsLambda => {
154                self.init_lambda().await?;
155            }
156            ProviderType::CloudflareWorkers => {
157                self.init_cloudflare().await?;
158            }
159            _ => {}
160        }
161
162        self.initialized = true;
163        Ok(())
164    }
165
166    async fn init_lambda(&self) -> Result<()> {
167        // Lambda-specific initialization
168        // - Set up X-Ray tracing
169        // - Configure memory allocator
170        // - Initialize extensions
171        Ok(())
172    }
173
174    async fn init_cloudflare(&self) -> Result<()> {
175        // Cloudflare-specific initialization
176        // - Configure KV access
177        // - Set up Workers AI bindings
178        Ok(())
179    }
180
181    /// Get configuration
182    pub fn config(&self) -> &ProviderConfig {
183        &self.config
184    }
185
186    /// Check if initialized
187    pub fn is_initialized(&self) -> bool {
188        self.initialized
189    }
190
191    /// Get remaining execution time (ms)
192    pub fn remaining_time_ms(&self) -> Option<u64> {
193        match self.config.provider_type {
194            ProviderType::AwsLambda => {
195                // Lambda provides this via context
196                Some(self.config.timeout_seconds * 1000)
197            }
198            _ => Some(self.config.timeout_seconds * 1000),
199        }
200    }
201
202    /// Check if request size is within limits
203    pub fn validate_payload_size(&self, size: usize) -> Result<()> {
204        if size as u64 > self.config.max_payload_size {
205            return Err(ServerlessError::ProviderError(format!(
206                "Payload size {} exceeds limit {}",
207                size, self.config.max_payload_size
208            )));
209        }
210        Ok(())
211    }
212
213    /// Get available memory
214    pub fn available_memory_mb(&self) -> u64 {
215        self.config.memory_mb
216    }
217
218    /// Check GPU availability
219    pub fn has_gpu(&self) -> bool {
220        self.config.gpu_available
221    }
222}
223
224/// Provider capabilities
225#[derive(Debug, Clone, Serialize, Deserialize)]
226pub struct ProviderCapabilities {
227    /// Supports WebSocket
228    pub websocket: bool,
229    /// Supports streaming responses
230    pub streaming: bool,
231    /// Supports GPU
232    pub gpu: bool,
233    /// Supports persistent storage
234    pub storage: bool,
235    /// Supports scheduled execution
236    pub scheduled: bool,
237    /// Maximum memory MB
238    pub max_memory_mb: u64,
239    /// Maximum timeout seconds
240    pub max_timeout_seconds: u64,
241}
242
243impl ProviderCapabilities {
244    /// Get capabilities for provider type
245    pub fn for_provider(provider: ProviderType) -> Self {
246        match provider {
247            ProviderType::AwsLambda => Self {
248                websocket: false,
249                streaming: true,
250                gpu: false,
251                storage: false, // Need S3
252                scheduled: true,
253                max_memory_mb: 10240,
254                max_timeout_seconds: 900,
255            },
256            ProviderType::CloudflareWorkers => Self {
257                websocket: true,
258                streaming: true,
259                gpu: false,
260                storage: true, // KV, R2, D1
261                scheduled: true,
262                max_memory_mb: 128,
263                max_timeout_seconds: 30,
264            },
265            ProviderType::GoogleCloudFunctions => Self {
266                websocket: false,
267                streaming: false,
268                gpu: false,
269                storage: false,
270                scheduled: true,
271                max_memory_mb: 32768,
272                max_timeout_seconds: 3600,
273            },
274            ProviderType::AzureFunctions => Self {
275                websocket: false,
276                streaming: true,
277                gpu: false,
278                storage: false,
279                scheduled: true,
280                max_memory_mb: 14336,
281                max_timeout_seconds: 600,
282            },
283            ProviderType::Custom => Self {
284                websocket: true,
285                streaming: true,
286                gpu: true,
287                storage: true,
288                scheduled: true,
289                max_memory_mb: u64::MAX,
290                max_timeout_seconds: u64::MAX,
291            },
292        }
293    }
294}
295
296/// Request context from provider
297#[derive(Debug, Clone, Serialize, Deserialize)]
298pub struct RequestContext {
299    /// Request ID
300    pub request_id: String,
301    /// Function name
302    pub function_name: String,
303    /// Invocation count
304    pub invocation_count: u64,
305    /// Memory limit MB
306    pub memory_limit_mb: u64,
307    /// Timeout remaining ms
308    pub timeout_remaining_ms: u64,
309    /// Is cold start
310    pub is_cold_start: bool,
311}
312
313impl RequestContext {
314    /// Create from AWS Lambda context
315    pub fn from_lambda_env() -> Option<Self> {
316        Some(Self {
317            request_id: std::env::var("_X_AMZN_TRACE_ID").ok()?,
318            function_name: std::env::var("AWS_LAMBDA_FUNCTION_NAME").ok()?,
319            invocation_count: 0, // Not available in env
320            memory_limit_mb: std::env::var("AWS_LAMBDA_FUNCTION_MEMORY_SIZE")
321                .ok()?
322                .parse()
323                .ok()?,
324            timeout_remaining_ms: 0, // Set from context
325            is_cold_start: false,    // Determined at runtime
326        })
327    }
328}
329
#[cfg(test)]
mod tests {
    use super::*;

    /// Display names match the provider's marketing name.
    #[test]
    fn test_provider_type_name() {
        let lambda = ProviderType::AwsLambda;
        let workers = ProviderType::CloudflareWorkers;

        assert_eq!(lambda.name(), "AWS Lambda");
        assert_eq!(workers.name(), "Cloudflare Workers");
    }

    /// The default configuration is a custom provider with 1024 MB.
    #[test]
    fn test_config_default() {
        let ProviderConfig {
            provider_type,
            memory_mb,
            ..
        } = ProviderConfig::default();

        assert_eq!(provider_type, ProviderType::Custom);
        assert_eq!(memory_mb, 1024);
    }

    /// The Lambda preset keeps the requested memory and the 900 s timeout.
    #[test]
    fn test_aws_lambda_config() {
        let ProviderConfig {
            provider_type,
            memory_mb,
            timeout_seconds,
            ..
        } = ProviderConfig::aws_lambda(2048);

        assert_eq!(provider_type, ProviderType::AwsLambda);
        assert_eq!(memory_mb, 2048);
        assert_eq!(timeout_seconds, 900);
    }

    /// The Workers preset uses 128 MB of memory.
    #[test]
    fn test_cloudflare_config() {
        let ProviderConfig {
            provider_type,
            memory_mb,
            ..
        } = ProviderConfig::cloudflare_workers();

        assert_eq!(provider_type, ProviderType::CloudflareWorkers);
        assert_eq!(memory_mb, 128);
    }

    /// A freshly created provider is uninitialized and exposes its memory.
    #[test]
    fn test_provider_creation() {
        let provider = Provider::new(ProviderConfig::default());

        assert!(!provider.is_initialized());
        assert_eq!(provider.available_memory_mb(), 1024);
    }

    /// Payloads at or under the limit pass; larger ones are rejected.
    #[test]
    fn test_payload_validation() {
        let provider = Provider::new(ProviderConfig {
            max_payload_size: 1024,
            ..Default::default()
        });

        let within_limit = provider.validate_payload_size(512);
        let over_limit = provider.validate_payload_size(2048);

        assert!(within_limit.is_ok());
        assert!(over_limit.is_err());
    }

    /// Spot-checks the capability table for Lambda and Workers.
    #[test]
    fn test_capabilities() {
        let lambda_caps = ProviderCapabilities::for_provider(ProviderType::AwsLambda);
        assert!(!lambda_caps.websocket);
        assert!(lambda_caps.streaming);
        assert_eq!(lambda_caps.max_memory_mb, 10240);

        let cf_caps = ProviderCapabilities::for_provider(ProviderType::CloudflareWorkers);
        assert!(cf_caps.websocket);
        assert!(cf_caps.storage);
    }
}