use std::collections::HashMap;
use std::time::Duration;
/// Estimates how long tool invocations and LLM responses are expected to take.
///
/// Tool estimates come from a per-tool lookup table that is seeded by
/// [`TimeEstimator::new`] and can be overridden or extended through
/// [`TimeEstimator::set_tool_duration`].
#[derive(Debug, Clone)]
pub struct TimeEstimator {
    /// Expected duration of each known tool, keyed by tool name.
    tool_durations: HashMap<String, Duration>,
}

impl TimeEstimator {
    /// Estimate reported by `estimate_tool` for a tool that has no table entry.
    const FALLBACK_TOOL_DURATION: Duration = Duration::from_secs(5);

    /// Tool estimates installed by `new`.
    const BUILTIN_TOOL_DURATIONS: [(&'static str, Duration); 4] = [
        ("http", Duration::from_secs(5)),
        ("echo", Duration::from_millis(10)),
        ("time", Duration::from_millis(1)),
        ("json", Duration::from_millis(5)),
    ];

    /// Token throughput assumed by `estimate_llm_response`.
    const TOKENS_PER_SECOND: f64 = 50.0;

    /// Lower bound, in seconds, of any `estimate_llm_response` result.
    const MIN_LLM_RESPONSE_SECS: f64 = 1.0;

    /// Creates an estimator pre-populated with estimates for the `http`,
    /// `echo`, `time` and `json` tools.
    pub fn new() -> Self {
        let tool_durations = Self::BUILTIN_TOOL_DURATIONS
            .iter()
            .map(|&(name, duration)| (name.to_string(), duration))
            .collect();
        Self { tool_durations }
    }

    /// Returns the estimated duration of one invocation of `tool_name`.
    ///
    /// The lookup is an exact match on the tool name. Tools without an entry
    /// in the table are estimated at five seconds.
    pub fn estimate_tool(&self, tool_name: &str) -> Duration {
        self.tool_durations
            .get(tool_name)
            .copied()
            .unwrap_or(Self::FALLBACK_TOOL_DURATION)
    }

    /// Returns the estimated time needed to produce `estimated_tokens` tokens.
    ///
    /// The estimate assumes 50 tokens per second and is never shorter than one
    /// second, so small (or zero) token counts all yield exactly one second.
    pub fn estimate_llm_response(&self, estimated_tokens: u32) -> Duration {
        // `f64::from` is a lossless widening of the `u32`; the result is finite
        // and non-negative, so `Duration::from_secs_f64` cannot panic here.
        let seconds = f64::from(estimated_tokens) / Self::TOKENS_PER_SECOND;
        Duration::from_secs_f64(seconds.max(Self::MIN_LLM_RESPONSE_SECS))
    }

    /// Sets the estimate for `tool_name`, replacing any existing entry.
    pub fn set_tool_duration(&mut self, tool_name: impl Into<String>, duration: Duration) {
        self.tool_durations.insert(tool_name.into(), duration);
    }

    /// Returns a read-only view of every tool estimate currently in the table.
    pub fn all_tool_durations(&self) -> &HashMap<String, Duration> {
        &self.tool_durations
    }
}
impl Default for TimeEstimator {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The built-in `echo` estimate must be under one second and the built-in
    /// `http` estimate must be at least one second.
    #[test]
    fn test_tool_time_estimation() {
        let one_second = Duration::from_secs(1);
        let est = TimeEstimator::new();

        let echo_estimate = est.estimate_tool("echo");
        assert!(echo_estimate < one_second);

        let http_estimate = est.estimate_tool("http");
        assert!(http_estimate >= one_second);
    }

    /// A 500-token response must be estimated at no less than one second.
    #[test]
    fn test_llm_time_estimation() {
        let est = TimeEstimator::new();
        let estimate = est.estimate_llm_response(500);
        assert!(estimate >= Duration::from_secs(1));
    }
}