Skip to main content

forge_guardrails/server/
budget.rs

1use std::path::Path;
2
3use crate::context::detect_hardware;
4use crate::error::BudgetResolutionError;
5
6use super::lifecycle::LifecycleOptions;
7use super::manager::ServerManager;
8
/// Budget resolution strategy.
///
/// Every variant has one canonical lowercase, hyphenated name. That name is
/// what [`BudgetMode::as_str`] returns and the only spelling accepted by
/// [`BudgetMode::parse`] and the [`std::str::FromStr`] implementation, so
/// `BudgetMode::parse(mode.as_str()) == Some(mode)` holds for every variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BudgetMode {
    /// Use backend-reported context size.
    Backend,
    /// User-specified token count.
    Manual,
    /// Maximum available context.
    ForgeFull,
    /// Half available context (two-phase start for llamaserver).
    ForgeFast,
}

impl BudgetMode {
    /// Returns the canonical string name of the budget mode.
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::Backend => "backend",
            Self::Manual => "manual",
            Self::ForgeFull => "forge-full",
            Self::ForgeFast => "forge-fast",
        }
    }

    /// Parses a canonical name into a `BudgetMode`.
    ///
    /// Matching is exact and case-sensitive; no trimming is performed.
    /// Returns `None` for any other input.
    pub fn parse(s: &str) -> Option<Self> {
        match s {
            "backend" => Some(Self::Backend),
            "manual" => Some(Self::Manual),
            "forge-full" => Some(Self::ForgeFull),
            "forge-fast" => Some(Self::ForgeFast),
            _ => None,
        }
    }
}

impl std::str::FromStr for BudgetMode {
    type Err = ();

    /// Same contract as [`BudgetMode::parse`], with `None` mapped to `Err(())`.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // Delegate so the name table exists in exactly one place and the two
        // parsing entry points cannot disagree.
        Self::parse(s).ok_or(())
    }
}
58
59impl ServerManager {
60    #[allow(clippy::too_many_arguments)]
61    pub(super) fn start_with_budget_options(
62        &self,
63        model: &str,
64        gguf_path: &Path,
65        mode: &str,
66        budget_mode: BudgetMode,
67        manual_tokens: Option<i64>,
68        extra_flags: &[String],
69        cache_type_k: Option<&str>,
70        cache_type_v: Option<&str>,
71        n_slots: Option<i64>,
72        kv_unified: bool,
73        options: LifecycleOptions,
74    ) -> Result<i64, String> {
75        if budget_mode == BudgetMode::Manual && manual_tokens.is_none() {
76            return Err("manual mode requires manual_tokens".to_string());
77        }
78
79        if self.backend == "ollama" {
80            self.start_with_options(
81                model,
82                gguf_path,
83                mode,
84                extra_flags,
85                None,
86                cache_type_k,
87                cache_type_v,
88                n_slots,
89                kv_unified,
90                options,
91            )
92            .map_err(|e| e.to_string())?;
93            return self
94                .resolve_budget(budget_mode, manual_tokens, n_slots, kv_unified)
95                .map_err(|e| e.to_string());
96        }
97
98        if budget_mode == BudgetMode::ForgeFast {
99            self.start_with_options(
100                model,
101                gguf_path,
102                mode,
103                extra_flags,
104                None,
105                cache_type_k,
106                cache_type_v,
107                n_slots,
108                kv_unified,
109                options,
110            )
111            .map_err(|e| e.to_string())?;
112            let reported_ctx = self.query_props_context().map_err(|e| e.to_string())?;
113            let total_ctx = if kv_unified || n_slots.is_none_or(|slots| slots <= 1) {
114                reported_ctx
115            } else {
116                reported_ctx * n_slots.unwrap_or(1)
117            };
118            let half_total = total_ctx / 2;
119            if let Ok(mut g) = self.last_context.lock() {
120                *g = Some(half_total);
121            }
122            self.start_with_options(
123                model,
124                gguf_path,
125                mode,
126                extra_flags,
127                Some(half_total),
128                cache_type_k,
129                cache_type_v,
130                n_slots,
131                kv_unified,
132                options,
133            )
134            .map_err(|e| e.to_string())?;
135            return self
136                .resolve_budget(budget_mode, manual_tokens, n_slots, kv_unified)
137                .map_err(|e| e.to_string());
138        }
139
140        let ctx_override = if budget_mode == BudgetMode::Manual {
141            manual_tokens
142        } else {
143            None
144        };
145        self.start_with_options(
146            model,
147            gguf_path,
148            mode,
149            extra_flags,
150            ctx_override,
151            cache_type_k,
152            cache_type_v,
153            n_slots,
154            kv_unified,
155            options,
156        )
157        .map_err(|e| e.to_string())?;
158        self.resolve_budget(budget_mode, manual_tokens, n_slots, kv_unified)
159            .map_err(|e| e.to_string())
160    }
161
162    /// Resolve the context budget based on the given mode.
163    pub fn resolve_budget(
164        &self,
165        mode: BudgetMode,
166        manual_tokens: Option<i64>,
167        n_slots: Option<i64>,
168        kv_unified: bool,
169    ) -> Result<i64, BudgetResolutionError> {
170        match mode {
171            BudgetMode::Manual => {
172                if self.backend == "ollama" {
173                    manual_tokens.ok_or_else(|| {
174                        BudgetResolutionError::new()
175                            .with_cause("manual_tokens required for MANUAL budget mode")
176                    })
177                } else {
178                    self.resolve_backend_budget()
179                }
180            }
181            BudgetMode::Backend => self.resolve_backend_budget(),
182            BudgetMode::ForgeFull => self.resolve_forge_full(n_slots, kv_unified),
183            BudgetMode::ForgeFast => self.resolve_forge_fast(n_slots, kv_unified),
184        }
185    }
186
187    fn resolve_backend_budget(&self) -> Result<i64, BudgetResolutionError> {
188        if self.backend == "ollama" {
189            return Ok(Self::ollama_vram_budget());
190        }
191        let ctx = self.query_props_context()?;
192        Ok(ctx)
193    }
194
195    fn resolve_forge_full(
196        &self,
197        _n_slots: Option<i64>,
198        _kv_unified: bool,
199    ) -> Result<i64, BudgetResolutionError> {
200        if self.backend == "ollama" {
201            return Ok(Self::ollama_vram_budget());
202        }
203        self.query_props_context()
204    }
205
206    fn resolve_forge_fast(
207        &self,
208        _n_slots: Option<i64>,
209        _kv_unified: bool,
210    ) -> Result<i64, BudgetResolutionError> {
211        if self.backend == "ollama" {
212            return Ok(Self::ollama_vram_budget() / 2);
213        }
214        self.query_props_context()
215    }
216
217    /// VRAM tier budget for ollama.
218    pub fn ollama_vram_budget() -> i64 {
219        match detect_hardware() {
220            Ok(Some(hw)) => {
221                let gb = hw.vram_total_gb();
222                if gb < 24.0 {
223                    4096
224                } else if gb < 48.0 {
225                    32768
226                } else {
227                    262144
228                }
229            }
230            _ => 4096,
231        }
232    }
233
234    /// Query the backend /props endpoint for the actual context length.
235    pub fn query_props_context(&self) -> Result<i64, BudgetResolutionError> {
236        let url = format!("http://127.0.0.1:{}/props", self.port);
237        let rt = tokio::runtime::Builder::new_current_thread()
238            .enable_io()
239            .enable_time()
240            .build()
241            .map_err(|e| BudgetResolutionError::new().with_cause(e.to_string()))?;
242
243        rt.block_on(async {
244            let resp = reqwest::Client::new()
245                .get(&url)
246                .timeout(std::time::Duration::from_secs(5))
247                .send()
248                .await
249                .map_err(|e| BudgetResolutionError::new().with_cause(e.to_string()))?;
250
251            if !resp.status().is_success() {
252                return Err(
253                    BudgetResolutionError::new().with_cause(format!("Status {}", resp.status()))
254                );
255            }
256
257            let json: serde_json::Value = resp
258                .json()
259                .await
260                .map_err(|e| BudgetResolutionError::new().with_cause(e.to_string()))?;
261
262            let ctx = json
263                .get("default_generation_settings")
264                .and_then(|s| s.get("n_ctx"))
265                .and_then(|n| n.as_i64())
266                .ok_or_else(|| {
267                    BudgetResolutionError::new()
268                        .with_cause("missing context field in /props response")
269                })?;
270
271            Ok(ctx)
272        })
273    }
274}