use crate::{Result, ServerlessError};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Serverless platforms known to this module.
///
/// `Custom` is the catch-all for anything not listed; it is the provider type
/// used by `ProviderConfig::default()`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ProviderType {
    /// AWS Lambda.
    AwsLambda,
    /// Cloudflare Workers.
    CloudflareWorkers,
    /// Google Cloud Functions.
    GoogleCloudFunctions,
    /// Azure Functions.
    AzureFunctions,
    /// Any platform not covered by the other variants.
    Custom,
}

impl ProviderType {
    /// Human-readable platform name, e.g. `"AWS Lambda"`.
    pub fn name(&self) -> &'static str {
        match self {
            Self::AwsLambda => "AWS Lambda",
            Self::CloudflareWorkers => "Cloudflare Workers",
            Self::GoogleCloudFunctions => "Google Cloud Functions",
            Self::AzureFunctions => "Azure Functions",
            Self::Custom => "Custom",
        }
    }

    /// Guesses the current platform from well-known environment variables.
    ///
    /// Probes run in a fixed order (AWS, Cloudflare, Google, Azure) and the
    /// first hit wins. A variable only counts as present when
    /// `std::env::var` succeeds, i.e. it is set to valid Unicode.
    /// Returns `None` when nothing matches; `Custom` is never returned.
    pub fn detect() -> Option<Self> {
        let is_set = |key: &str| std::env::var(key).is_ok();

        if is_set("AWS_LAMBDA_FUNCTION_NAME") {
            return Some(Self::AwsLambda);
        }
        if is_set("CF_WORKER") {
            return Some(Self::CloudflareWorkers);
        }
        // Google requires both variables (presumably because FUNCTION_NAME
        // alone is too ambiguous to identify the platform).
        if is_set("FUNCTION_NAME") && is_set("GOOGLE_CLOUD_PROJECT") {
            return Some(Self::GoogleCloudFunctions);
        }
        if is_set("FUNCTIONS_WORKER_RUNTIME") {
            return Some(Self::AzureFunctions);
        }
        None
    }
}
/// Limits and settings describing one serverless deployment target.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProviderConfig {
    /// Platform this configuration targets.
    pub provider_type: ProviderType,
    /// Memory available to the function, in MB.
    pub memory_mb: u64,
    /// Execution timeout, in seconds.
    pub timeout_seconds: u64,
    /// Whether a GPU is available.
    pub gpu_available: bool,
    /// Largest accepted payload, in bytes (e.g. 6 MiB for the defaults).
    pub max_payload_size: u64,
    /// Extra environment variables; not read anywhere in this module.
    pub env_vars: HashMap<String, String>,
    /// Free-form extra settings; not read anywhere in this module.
    pub custom: HashMap<String, String>,
}

impl Default for ProviderConfig {
    /// A `Custom` target with 1024 MB memory, a 30 s timeout, no GPU,
    /// a 6 MiB payload cap and empty maps.
    fn default() -> Self {
        Self {
            provider_type: ProviderType::Custom,
            memory_mb: 1024,
            timeout_seconds: 30,
            gpu_available: false,
            max_payload_size: 6 * Self::BYTES_PER_MIB,
            env_vars: HashMap::default(),
            custom: HashMap::default(),
        }
    }
}

impl ProviderConfig {
    /// Number of bytes in one mebibyte.
    const BYTES_PER_MIB: u64 = 1024 * 1024;

    /// AWS Lambda preset with the given memory size.
    ///
    /// Uses a 900 s (15 minute) timeout, no GPU and a 6 MiB payload cap;
    /// the maps are empty.
    pub fn aws_lambda(memory_mb: u64) -> Self {
        Self {
            provider_type: ProviderType::AwsLambda,
            memory_mb,
            timeout_seconds: 15 * 60,
            gpu_available: false,
            max_payload_size: 6 * Self::BYTES_PER_MIB,
            ..Self::default()
        }
    }

    /// Cloudflare Workers preset.
    ///
    /// Uses 128 MB memory, a 30 s timeout, no GPU and a 100 MiB payload cap;
    /// the maps are empty.
    pub fn cloudflare_workers() -> Self {
        Self {
            provider_type: ProviderType::CloudflareWorkers,
            memory_mb: 128,
            timeout_seconds: 30,
            gpu_available: false,
            max_payload_size: 100 * Self::BYTES_PER_MIB,
            ..Self::default()
        }
    }
}
#[derive(Debug)]
pub struct Provider {
config: ProviderConfig,
initialized: bool,
}
impl Provider {
pub fn new(config: ProviderConfig) -> Self {
Self {
config,
initialized: false,
}
}
pub fn from_env() -> Result<Self> {
let provider_type = ProviderType::detect()
.ok_or_else(|| ServerlessError::ProviderError("Unknown provider".into()))?;
let config = match provider_type {
ProviderType::AwsLambda => {
let memory = std::env::var("AWS_LAMBDA_FUNCTION_MEMORY_SIZE")
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(1024);
ProviderConfig::aws_lambda(memory)
}
ProviderType::CloudflareWorkers => ProviderConfig::cloudflare_workers(),
_ => ProviderConfig::default(),
};
Ok(Self::new(config))
}
pub async fn initialize(&mut self) -> Result<()> {
match self.config.provider_type {
ProviderType::AwsLambda => {
self.init_lambda().await?;
}
ProviderType::CloudflareWorkers => {
self.init_cloudflare().await?;
}
_ => {}
}
self.initialized = true;
Ok(())
}
async fn init_lambda(&self) -> Result<()> {
Ok(())
}
async fn init_cloudflare(&self) -> Result<()> {
Ok(())
}
pub fn config(&self) -> &ProviderConfig {
&self.config
}
pub fn is_initialized(&self) -> bool {
self.initialized
}
pub fn remaining_time_ms(&self) -> Option<u64> {
match self.config.provider_type {
ProviderType::AwsLambda => {
Some(self.config.timeout_seconds * 1000)
}
_ => Some(self.config.timeout_seconds * 1000),
}
}
pub fn validate_payload_size(&self, size: usize) -> Result<()> {
if size as u64 > self.config.max_payload_size {
return Err(ServerlessError::ProviderError(format!(
"Payload size {} exceeds limit {}",
size, self.config.max_payload_size
)));
}
Ok(())
}
pub fn available_memory_mb(&self) -> u64 {
self.config.memory_mb
}
pub fn has_gpu(&self) -> bool {
self.config.gpu_available
}
}
/// Hard-coded feature flags and limits for a serverless platform.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProviderCapabilities {
    /// WebSocket support.
    pub websocket: bool,
    /// Streaming support.
    pub streaming: bool,
    /// GPU support.
    pub gpu: bool,
    /// Storage support.
    pub storage: bool,
    /// Scheduled-invocation support.
    pub scheduled: bool,
    /// Largest memory setting, in MB.
    pub max_memory_mb: u64,
    /// Longest timeout, in seconds.
    pub max_timeout_seconds: u64,
}

impl ProviderCapabilities {
    /// Returns the capability table for `provider`.
    ///
    /// Every platform reports `scheduled: true`; `Custom` additionally
    /// reports every feature enabled and `u64::MAX` for both limits.
    pub fn for_provider(provider: ProviderType) -> Self {
        // Common starting point: only scheduling enabled. Every arm below
        // overrides both numeric limits, so these zeros are never returned.
        let base = Self {
            websocket: false,
            streaming: false,
            gpu: false,
            storage: false,
            scheduled: true,
            max_memory_mb: 0,
            max_timeout_seconds: 0,
        };

        match provider {
            ProviderType::AwsLambda => Self {
                streaming: true,
                max_memory_mb: 10_240,
                max_timeout_seconds: 900,
                ..base
            },
            ProviderType::CloudflareWorkers => Self {
                websocket: true,
                streaming: true,
                storage: true,
                max_memory_mb: 128,
                max_timeout_seconds: 30,
                ..base
            },
            ProviderType::GoogleCloudFunctions => Self {
                max_memory_mb: 32_768,
                max_timeout_seconds: 3_600,
                ..base
            },
            ProviderType::AzureFunctions => Self {
                streaming: true,
                max_memory_mb: 14_336,
                max_timeout_seconds: 600,
                ..base
            },
            ProviderType::Custom => Self {
                websocket: true,
                streaming: true,
                gpu: true,
                storage: true,
                max_memory_mb: u64::MAX,
                max_timeout_seconds: u64::MAX,
                ..base
            },
        }
    }
}
/// Metadata about a single function invocation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RequestContext {
    /// Request identifier (see `from_lambda_env` for its current source).
    pub request_id: String,
    /// Name of the deployed function.
    pub function_name: String,
    /// Invocation counter; `from_lambda_env` always sets 0.
    pub invocation_count: u64,
    /// Memory limit, in MB.
    pub memory_limit_mb: u64,
    /// Milliseconds left before timeout; `from_lambda_env` always sets 0.
    pub timeout_remaining_ms: u64,
    /// Cold-start flag; `from_lambda_env` always sets `false`.
    pub is_cold_start: bool,
}

impl RequestContext {
    /// Builds a context from AWS Lambda environment variables.
    ///
    /// Returns `None` if `_X_AMZN_TRACE_ID` or `AWS_LAMBDA_FUNCTION_NAME` is
    /// missing (or not valid Unicode), or if `AWS_LAMBDA_FUNCTION_MEMORY_SIZE`
    /// is missing or does not parse as `u64`. The counter, remaining-time and
    /// cold-start fields are fixed placeholders.
    pub fn from_lambda_env() -> Option<Self> {
        let read = |key: &str| std::env::var(key).ok();

        // NOTE(review): `_X_AMZN_TRACE_ID` is the X-Ray trace header, not the
        // Lambda request id; confirm this is the intended source.
        let request_id = read("_X_AMZN_TRACE_ID")?;
        let function_name = read("AWS_LAMBDA_FUNCTION_NAME")?;
        let memory_limit_mb = read("AWS_LAMBDA_FUNCTION_MEMORY_SIZE")?
            .parse::<u64>()
            .ok()?;

        Some(Self {
            request_id,
            function_name,
            invocation_count: 0,
            memory_limit_mb,
            timeout_remaining_ms: 0,
            is_cold_start: false,
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_provider_type_name() {
        // Table of (variant, expected display name).
        let cases = [
            (ProviderType::AwsLambda, "AWS Lambda"),
            (ProviderType::CloudflareWorkers, "Cloudflare Workers"),
        ];
        for &(provider, expected) in &cases {
            assert_eq!(provider.name(), expected);
        }
    }

    #[test]
    fn test_config_default() {
        let ProviderConfig {
            provider_type,
            memory_mb,
            ..
        } = ProviderConfig::default();
        assert_eq!(provider_type, ProviderType::Custom);
        assert_eq!(memory_mb, 1024);
    }

    #[test]
    fn test_aws_lambda_config() {
        let ProviderConfig {
            provider_type,
            memory_mb,
            timeout_seconds,
            ..
        } = ProviderConfig::aws_lambda(2048);
        assert_eq!(provider_type, ProviderType::AwsLambda);
        assert_eq!(memory_mb, 2048);
        assert_eq!(timeout_seconds, 900);
    }

    #[test]
    fn test_cloudflare_config() {
        let ProviderConfig {
            provider_type,
            memory_mb,
            ..
        } = ProviderConfig::cloudflare_workers();
        assert_eq!(provider_type, ProviderType::CloudflareWorkers);
        assert_eq!(memory_mb, 128);
    }

    #[test]
    fn test_provider_creation() {
        let provider = Provider::new(ProviderConfig::default());
        assert!(!provider.is_initialized());
        assert_eq!(provider.available_memory_mb(), 1024);
    }

    #[test]
    fn test_payload_validation() {
        let provider = Provider::new(ProviderConfig {
            max_payload_size: 1024,
            ..ProviderConfig::default()
        });
        // Below the 1024-byte cap passes; above it fails.
        assert!(provider.validate_payload_size(512).is_ok());
        assert!(provider.validate_payload_size(2048).is_err());
    }

    #[test]
    fn test_capabilities() {
        let lambda = ProviderCapabilities::for_provider(ProviderType::AwsLambda);
        assert!(!lambda.websocket);
        assert!(lambda.streaming);
        assert_eq!(lambda.max_memory_mb, 10240);

        let workers = ProviderCapabilities::for_provider(ProviderType::CloudflareWorkers);
        assert!(workers.websocket);
        assert!(workers.storage);
    }
}