use crate::priority::{Priority, WorkloadType};
use serde::{Deserialize, Serialize};
use std::time::Instant;
/// A granted memory allocation, stamped with the moment it was created.
#[derive(Debug, Clone)]
pub struct Allocation {
    /// Identifier of this allocation.
    pub id: String,
    /// Kind of workload this allocation serves.
    pub workload_type: WorkloadType,
    /// Priority the allocation was granted at.
    pub priority: Priority,
    /// Size of the allocation in bytes (see [`Allocation::memory_mb`]).
    pub memory_allocated: u64,
    /// Quality level targeted for this allocation.
    pub quality_target: f32,
    /// Creation timestamp; the reference point for [`Allocation::age`].
    pub created_at: Instant,
}

impl Allocation {
    /// Time elapsed since this allocation was created.
    pub fn age(&self) -> std::time::Duration {
        Instant::elapsed(&self.created_at)
    }

    /// `true` when the workload is LLM inference.
    pub fn is_llm(&self) -> bool {
        matches!(&self.workload_type, WorkloadType::LlmInference)
    }

    /// `true` for image-generation and video-generation workloads.
    pub fn is_diffusion(&self) -> bool {
        let kind = &self.workload_type;
        matches!(kind, WorkloadType::ImageGeneration | WorkloadType::VideoGeneration)
    }

    /// Allocation size in whole mebibytes (2^20 bytes), rounded down.
    pub fn memory_mb(&self) -> u64 {
        const BYTES_PER_MIB: u64 = 1 << 20;
        self.memory_allocated / BYTES_PER_MIB
    }
}
/// A request for a new [`Allocation`], built from one of the per-workload
/// constructors and refined with the `with_*` / `no_wait` builder methods.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AllocationRequest {
    /// Kind of workload the memory is requested for.
    pub workload_type: WorkloadType,
    /// Scheduling priority of the request.
    pub priority: Priority,
    /// Amount of memory requested (presumably bytes, like
    /// `Allocation::memory_allocated` — confirm with the allocator).
    pub memory_required: u64,
    /// Caller-supplied quality floor; `None` falls back to the workload default.
    ///
    /// This field is public and deserializable, so it may hold values that never
    /// went through [`AllocationRequest::with_min_quality`] (out of range or NaN).
    /// Read it via [`AllocationRequest::effective_min_quality`].
    pub min_quality: Option<f32>,
    /// Whether the caller is willing to wait for resources.
    pub wait_for_resources: bool,
    /// Optional wait limit in milliseconds.
    pub timeout_ms: Option<u64>,
    /// Free-form caller metadata.
    pub metadata: Option<String>,
}

impl AllocationRequest {
    /// Shared constructor body for the per-workload constructors. It sets normal
    /// priority, no quality override, waiting enabled with the given timeout,
    /// and no metadata.
    fn for_workload(workload_type: WorkloadType, memory_required: u64, timeout_ms: u64) -> Self {
        Self {
            workload_type,
            priority: Priority::Normal,
            memory_required,
            min_quality: None,
            wait_for_resources: true,
            timeout_ms: Some(timeout_ms),
            metadata: None,
        }
    }

    /// LLM-inference request with a default 30 000 ms timeout.
    #[must_use]
    pub fn llm(memory_required: u64) -> Self {
        Self::for_workload(WorkloadType::LlmInference, memory_required, 30_000)
    }

    /// Image-generation request with a default 60 000 ms timeout.
    #[must_use]
    pub fn image(memory_required: u64) -> Self {
        Self::for_workload(WorkloadType::ImageGeneration, memory_required, 60_000)
    }

    /// Video-generation request with a default 120 000 ms timeout.
    #[must_use]
    pub fn video(memory_required: u64) -> Self {
        Self::for_workload(WorkloadType::VideoGeneration, memory_required, 120_000)
    }

    /// Overrides the request priority.
    #[must_use]
    pub fn with_priority(mut self, priority: Priority) -> Self {
        self.priority = priority;
        self
    }

    /// Sets an explicit quality floor, clamped into `[0.0, 1.0]`.
    ///
    /// A NaN argument clears the explicit floor (sets `None`). `f32::clamp` passes
    /// NaN through unchanged, and a NaN floor would make every later quality
    /// comparison false.
    #[must_use]
    pub fn with_min_quality(mut self, min_quality: f32) -> Self {
        self.min_quality = if min_quality.is_nan() {
            None
        } else {
            Some(min_quality.clamp(0.0, 1.0))
        };
        self
    }

    /// Fails fast instead of waiting: disables `wait_for_resources` and clears the
    /// timeout. A later [`AllocationRequest::with_timeout`] sets `timeout_ms` again
    /// but does not re-enable waiting.
    #[must_use]
    pub fn no_wait(mut self) -> Self {
        self.wait_for_resources = false;
        self.timeout_ms = None;
        self
    }

    /// Sets the wait limit in milliseconds.
    #[must_use]
    pub fn with_timeout(mut self, timeout_ms: u64) -> Self {
        self.timeout_ms = Some(timeout_ms);
        self
    }

    /// Attaches free-form metadata to the request.
    #[must_use]
    pub fn with_metadata(mut self, metadata: impl Into<String>) -> Self {
        self.metadata = Some(metadata.into());
        self
    }

    /// Returns the quality floor to enforce for this request.
    ///
    /// If `min_quality` is set, the result is that value clamped into `[0.0, 1.0]`.
    /// If it is unset or NaN, the result is `WorkloadType::min_quality` for this
    /// request's workload. The clamp is re-applied here because `min_quality` is
    /// public and deserializable, so it may never have passed through
    /// `with_min_quality`.
    #[must_use]
    pub fn effective_min_quality(&self) -> f32 {
        match self.min_quality {
            Some(q) if !q.is_nan() => q.clamp(0.0, 1.0),
            _ => self.workload_type.min_quality(),
        }
    }
}
/// Outcome of an allocation attempt.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AllocationResult {
    /// The allocation was granted.
    Success {
        /// Quality level granted.
        quality: f32,
        /// Memory granted.
        memory: u64,
    },
    /// The requested memory exceeded what was available.
    InsufficientMemory {
        /// Memory that was requested.
        requested: u64,
        /// Memory that was available.
        available: u64,
    },
    /// The requested quality floor could not be met.
    InsufficientQuality {
        /// Quality floor that was requested.
        requested: f32,
        /// Best quality that could be achieved.
        achievable: f32,
    },
    /// Gave up after waiting.
    Timeout {
        /// Time waited, in milliseconds.
        waited_ms: u64,
    },
    /// The allocation was preempted.
    Preempted,
}

impl AllocationResult {
    /// Returns `true` only for [`AllocationResult::Success`].
    #[must_use]
    pub fn is_success(&self) -> bool {
        matches!(self, Self::Success { .. })
    }

    /// Granted quality for a successful result; `None` for every other variant.
    #[must_use]
    pub fn quality(&self) -> Option<f32> {
        // Exhaustive on purpose: a newly added variant must decide here whether it
        // carries a quality, instead of being silently absorbed by a catch-all.
        match self {
            Self::Success { quality, .. } => Some(*quality),
            Self::InsufficientMemory { .. }
            | Self::InsufficientQuality { .. }
            | Self::Timeout { .. }
            | Self::Preempted => None,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that two quality values agree within a small tolerance.
    fn assert_quality_close(actual: f32, expected: f32) {
        assert!((actual - expected).abs() < 0.001);
    }

    #[test]
    fn test_allocation_age() {
        let allocation = Allocation {
            id: String::from("test"),
            workload_type: WorkloadType::LlmInference,
            priority: Priority::Normal,
            memory_allocated: 1024,
            quality_target: 1.0,
            created_at: Instant::now(),
        };
        let pause = std::time::Duration::from_millis(10);
        std::thread::sleep(pause);
        assert!(allocation.age() >= pause);
    }

    #[test]
    fn test_request_builder() {
        let one_gib = 1024 * 1024 * 1024;
        let request = AllocationRequest::llm(one_gib)
            .with_priority(Priority::High)
            .with_min_quality(0.8)
            .with_metadata("test inference");

        assert!(matches!(request.workload_type, WorkloadType::LlmInference));
        assert!(matches!(request.priority, Priority::High));
        assert_eq!(request.min_quality, Some(0.8));
        assert!(request.wait_for_resources);
    }

    #[test]
    fn test_effective_min_quality() {
        let defaulted = AllocationRequest::llm(1024);
        assert_quality_close(defaulted.effective_min_quality(), 0.4);

        let overridden = defaulted.with_min_quality(0.6);
        assert_quality_close(overridden.effective_min_quality(), 0.6);
    }

    #[test]
    fn test_allocation_result() {
        let granted = AllocationResult::Success {
            quality: 0.9,
            memory: 1024,
        };
        assert!(granted.is_success());
        assert_eq!(granted.quality(), Some(0.9));

        let starved = AllocationResult::InsufficientMemory {
            requested: 1000,
            available: 500,
        };
        assert!(!starved.is_success());
        assert!(starved.quality().is_none());
    }
}