Skip to main content

rs_adk/optimization/
optimizer.rs

1//! Base optimizer trait and result types.
2
3use async_trait::async_trait;
4use serde::{Deserialize, Serialize};
5
6/// Errors from optimization operations.
7#[derive(Debug, thiserror::Error)]
8pub enum OptimizerError {
9    /// Sampling failed.
10    #[error("Sampling error: {0}")]
11    Sampling(String),
12    /// Evaluation failed.
13    #[error("Evaluation error: {0}")]
14    Evaluation(String),
15    /// LLM generation failed.
16    #[error("LLM error: {0}")]
17    Llm(String),
18    /// Optimization logic error.
19    #[error("Optimization error: {0}")]
20    Optimization(String),
21}
22
23/// Result of an optimization run.
24#[derive(Debug, Clone, Serialize, Deserialize)]
25pub struct OptimizerResult {
26    /// The best instruction found during optimization.
27    pub best_instruction: String,
28    /// Score of the best instruction on the validation set.
29    pub best_score: f64,
30    /// Number of iterations performed.
31    pub iterations: usize,
32    /// Score history across iterations (iteration_number, score).
33    pub score_history: Vec<(usize, f64)>,
34}
35
36/// Trait for agent optimizers that iteratively improve agent instructions.
37///
38/// Mirrors ADK-Python's `AgentOptimizer` abstract class.
39#[async_trait]
40pub trait AgentOptimizer: Send + Sync {
41    /// Run the optimization process.
42    ///
43    /// # Arguments
44    /// * `initial_instruction` — The starting agent instruction to optimize.
45    /// * `model_id` — The model to use for the agent being optimized.
46    ///
47    /// # Returns
48    /// An [`OptimizerResult`] with the best instruction and scores.
49    async fn optimize(
50        &self,
51        initial_instruction: &str,
52        model_id: &str,
53    ) -> Result<OptimizerResult, OptimizerError>;
54}
55
#[cfg(test)]
mod tests {
    use super::*;

    /// Compile-time check that `AgentOptimizer` stays usable as a trait object.
    /// Never called; it only has to type-check.
    fn _assert_object_safe(_: &dyn AgentOptimizer) {}

    /// Float comparison with the same one-epsilon tolerance the round-trip
    /// assertion has always used.
    fn approx_eq(a: f64, b: f64) -> bool {
        (a - b).abs() < f64::EPSILON
    }

    #[test]
    fn optimizer_result_serde() {
        let result = OptimizerResult {
            best_instruction: "Be helpful".into(),
            best_score: 0.9,
            iterations: 5,
            score_history: vec![(0, 0.5), (1, 0.7), (2, 0.9)],
        };
        let json = serde_json::to_string(&result).unwrap();
        let deserialized: OptimizerResult = serde_json::from_str(&json).unwrap();

        // Check every field, not just the score, so a broken rename or tuple
        // encoding in any field is caught.
        assert_eq!(deserialized.best_instruction, result.best_instruction);
        assert!(approx_eq(deserialized.best_score, result.best_score));
        assert_eq!(deserialized.iterations, result.iterations);
        assert_eq!(
            deserialized.score_history.len(),
            result.score_history.len()
        );
        for (got, want) in deserialized.score_history.iter().zip(&result.score_history) {
            assert_eq!(got.0, want.0);
            assert!(
                approx_eq(got.1, want.1),
                "score mismatch at iteration {}",
                want.0
            );
        }
    }

    #[test]
    fn optimizer_error_display() {
        assert_eq!(
            OptimizerError::Sampling("s".into()).to_string(),
            "Sampling error: s"
        );
        assert_eq!(
            OptimizerError::Evaluation("e".into()).to_string(),
            "Evaluation error: e"
        );
        assert_eq!(
            OptimizerError::Llm("timeout".into()).to_string(),
            "LLM error: timeout"
        );
        assert_eq!(
            OptimizerError::Optimization("o".into()).to_string(),
            "Optimization error: o"
        );
    }
}