1pub mod memory;
15pub mod merge;
16pub mod optimizer;
17
18pub use memory::{MemoryPlanner, MemoryRequirement};
19pub use merge::MergeEngine;
20pub use optimizer::{LoraOptimizer, OptimalConfig};
21
22use entrenar_common::Result;
23
24pub fn plan(model_params: u64, available_vram_gb: f64, method: Method) -> Result<OptimalConfig> {
26 LoraOptimizer::new(model_params, available_vram_gb).optimize(method)
27}
28
/// Fine-tuning method selector accepted by [`plan`].
///
/// Values can be parsed from text case-insensitively via [`std::str::FromStr`]
/// using the keys `full`, `lora`, `qlora`, and `auto`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Method {
    /// Parsed from `"full"`. Presumably full-parameter fine-tuning; confirm
    /// against the optimizer.
    Full,
    /// Parsed from `"lora"`. Presumably low-rank adapter fine-tuning; confirm
    /// against the optimizer.
    LoRA,
    /// Parsed from `"qlora"`. Presumably quantized LoRA; confirm against the
    /// optimizer.
    QLoRA,
    /// Parsed from `"auto"`. Presumably lets the optimizer choose a method;
    /// confirm against the optimizer.
    Auto,
}
41
42impl std::str::FromStr for Method {
43 type Err = String;
44
45 fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
46 match s.to_lowercase().as_str() {
47 "full" => Ok(Self::Full),
48 "lora" => Ok(Self::LoRA),
49 "qlora" => Ok(Self::QLoRA),
50 "auto" => Ok(Self::Auto),
51 _ => Err(format!("Unknown method: {s}. Use: full, lora, qlora, auto")),
52 }
53 }
54}
55
#[cfg(test)]
mod tests {
    use super::*;

    /// Parameter count shared by every `plan` test case.
    const SEVEN_BILLION: u64 = 7_000_000_000;

    /// Parses `input` as a [`Method`], panicking if parsing fails.
    fn parse_ok(input: &str) -> Method {
        input.parse::<Method>().expect("parsing should succeed")
    }

    #[test]
    fn test_method_parsing() {
        let cases = [
            ("lora", Method::LoRA),
            ("QLoRA", Method::QLoRA),
            ("AUTO", Method::Auto),
        ];
        for (input, expected) in cases {
            assert_eq!(parse_ok(input), expected);
        }
    }

    #[test]
    fn test_plan_returns_config() {
        assert!(plan(SEVEN_BILLION, 16.0, Method::Auto).is_ok());
    }

    #[test]
    fn test_method_parsing_full() {
        for input in ["full", "FULL"] {
            assert_eq!(parse_ok(input), Method::Full);
        }
    }

    #[test]
    fn test_method_parsing_invalid() {
        let outcome = "invalid".parse::<Method>();
        assert!(outcome.is_err());

        let message = outcome.unwrap_err();
        assert!(message.contains("Unknown method"));
    }

    #[test]
    fn test_method_equality() {
        assert_eq!(Method::LoRA, Method::LoRA);
        assert_ne!(Method::LoRA, Method::QLoRA);
    }

    #[test]
    fn test_plan_with_specific_methods() {
        // Full fine-tuning is checked against a larger VRAM budget than the
        // adapter-based methods.
        let cases = [
            (16.0, Method::LoRA),
            (16.0, Method::QLoRA),
            (80.0, Method::Full),
        ];
        for (vram_gb, method) in cases {
            assert!(plan(SEVEN_BILLION, vram_gb, method).is_ok());
        }
    }
}