Skip to main content

synth_ai/
optimization.rs

1use std::collections::HashMap;
2use std::fs;
3use std::path::Path;
4
5use serde::{Deserialize, Serialize};
6use serde_json::{json, Map, Value};
7
8use crate::client::{AuthStyle, SynthClient};
9use crate::sse::{stream_sse, SseStream};
10use crate::types::{Result, SynthError};
11
12#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
13#[serde(rename_all = "lowercase")]
14pub enum Algorithm {
15    Gepa,
16    Mipro,
17}
18
19#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct PolicyOptimizationJobConfig {
21    pub config: Value,
22}
23
24impl PolicyOptimizationJobConfig {
25    pub fn from_json(config: Value) -> Self {
26        Self { config }
27    }
28
29    pub fn from_toml_str(input: &str) -> Result<Self> {
30        let value: toml::Value =
31            toml::from_str(input).map_err(|err| SynthError::UnexpectedResponse(err.to_string()))?;
32        let config = serde_json::to_value(value)?;
33        Ok(Self { config })
34    }
35
36    pub fn from_toml_file(path: impl AsRef<Path>) -> Result<Self> {
37        let content = fs::read_to_string(path)?;
38        Self::from_toml_str(&content)
39    }
40
41    pub fn to_payload(&self) -> Value {
42        let mut config = self.config.clone();
43        if let Value::Object(ref mut obj) = config {
44            if let Some(policy_opt) = obj.remove("policy_optimization") {
45                obj.insert("prompt_learning".to_string(), policy_opt);
46            }
47            if let Some(Value::Object(pl)) = obj.get_mut("prompt_learning") {
48                if let Some(local_url) = pl.remove("localapi_url") {
49                    pl.insert("task_app_url".to_string(), local_url);
50                }
51            }
52        }
53        config
54    }
55}
56
57#[derive(Debug, Clone, Serialize, Deserialize, Default)]
58pub struct PromptLearningResults {
59    pub best_prompt: Option<Value>,
60    pub best_score: Option<f64>,
61    pub top_prompts: Vec<Value>,
62    pub optimized_candidates: Vec<Value>,
63    pub attempted_candidates: Vec<Value>,
64    pub validation_results: Vec<Value>,
65}
66
67#[derive(Clone)]
68pub struct PolicyOptimizationJob {
69    client: SynthClient,
70    job_id: String,
71}
72
73impl PolicyOptimizationJob {
74    pub fn new(client: SynthClient, job_id: impl Into<String>) -> Self {
75        Self {
76            client,
77            job_id: job_id.into(),
78        }
79    }
80
81    pub fn job_id(&self) -> &str {
82        &self.job_id
83    }
84
85    pub async fn submit(client: SynthClient, config: &PolicyOptimizationJobConfig) -> Result<Self> {
86        let payload = config.to_payload();
87        let algorithm = payload
88            .get("prompt_learning")
89            .and_then(|v| v.get("algorithm"))
90            .and_then(|v| v.as_str())
91            .unwrap_or("gepa");
92        let submit_body = json!({
93            "algorithm": algorithm,
94            "config_body": payload,
95        });
96        let resp = client
97            .post_json_fallback(
98                &[
99                    "/policy-optimization/online/jobs",
100                    "/prompt-learning/online/jobs",
101                ],
102                &submit_body,
103                AuthStyle::Both,
104            )
105            .await?;
106        let job_id = resp
107            .get("job_id")
108            .and_then(|v| v.as_str())
109            .ok_or_else(|| SynthError::UnexpectedResponse("missing job_id".to_string()))?;
110        Ok(Self::new(client, job_id))
111    }
112
113    pub async fn status(&self) -> Result<Value> {
114        let path = format!(
115            "/policy-optimization/online/jobs/{}",
116            self.job_id
117        );
118        let fallback = format!("/prompt-learning/online/jobs/{}", self.job_id);
119        self.client
120            .get_json_fallback(
121                &[path.as_str(), fallback.as_str()],
122                AuthStyle::Both,
123            )
124            .await
125    }
126
127    pub async fn events(&self) -> Result<Vec<Value>> {
128        let path = format!(
129            "/policy-optimization/online/jobs/{}/events",
130            self.job_id
131        );
132        let fallback = format!("/prompt-learning/online/jobs/{}/events", self.job_id);
133        let value = self
134            .client
135            .get_json_fallback(
136                &[path.as_str(), fallback.as_str()],
137                AuthStyle::Both,
138            )
139            .await?;
140        parse_events(value)
141    }
142
143    pub async fn results(&self) -> Result<PromptLearningResults> {
144        let events = self.events().await?;
145        Ok(PromptLearningResults::from_events(&events))
146    }
147
148    pub async fn stream_events(&self) -> Result<SseStream> {
149        let primary = format!(
150            "{}/policy-optimization/online/jobs/{}/events/stream",
151            self.client.api_base(),
152            self.job_id
153        );
154        let fallback = format!(
155            "{}/prompt-learning/online/jobs/{}/events/stream",
156            self.client.api_base(),
157            self.job_id
158        );
159        let headers = self.client.auth_headers(AuthStyle::Both);
160        match stream_sse(self.client.http(), primary, headers.clone()).await {
161            Ok(stream) => Ok(stream),
162            Err(SynthError::Api { status: 404, .. }) => {
163                stream_sse(self.client.http(), fallback, headers).await
164            }
165            Err(err) => Err(err),
166        }
167    }
168}
169
170impl PromptLearningResults {
171    pub fn from_events(events: &[Value]) -> Self {
172        let mut results = PromptLearningResults::default();
173        let mut validation_by_rank: HashMap<i64, f64> = HashMap::new();
174
175        for event in events {
176            let event_type = event.get("type").and_then(|v| v.as_str()).unwrap_or("");
177            let data = event.get("data").and_then(|v| v.as_object());
178            if data.is_none() {
179                continue;
180            }
181            let data = data.unwrap();
182
183            match event_type {
184                "learning.policy.gepa.candidate.new_best" => {
185                    results.best_prompt = data.get("best_prompt").cloned();
186                    if results.best_score.is_none() {
187                        results.best_score = extract_reward_value(data, &["best_score"]);
188                    }
189                }
190                "learning.policy.gepa.candidate.evaluated" => {
191                    if let Some(rank) = data.get("rank").and_then(|v| v.as_i64()) {
192                        let mut prompt_entry = Map::new();
193                        prompt_entry.insert("rank".to_string(), json!(rank));
194                        prompt_entry.insert(
195                            "train_accuracy".to_string(),
196                            data.get("train_accuracy").cloned().unwrap_or(Value::Null),
197                        );
198                        prompt_entry.insert(
199                            "val_accuracy".to_string(),
200                            data.get("val_accuracy").cloned().unwrap_or(Value::Null),
201                        );
202                        if let Some(pattern) = data.get("pattern") {
203                            prompt_entry.insert("pattern".to_string(), pattern.clone());
204                            if let Some(text) = extract_full_text_from_pattern(pattern) {
205                                prompt_entry.insert("full_text".to_string(), json!(text));
206                            }
207                        } else if let Some(template) = data.get("template") {
208                            if let Some(pattern) = convert_template_to_pattern(template) {
209                                prompt_entry.insert("pattern".to_string(), pattern.clone());
210                                if let Some(text) = extract_full_text_from_pattern(&pattern) {
211                                    prompt_entry.insert("full_text".to_string(), json!(text));
212                                }
213                            }
214                        }
215                        results.top_prompts.push(Value::Object(prompt_entry));
216                    }
217                }
218                "learning.policy.gepa.job.completed" => {
219                    if let Some(cands) = data.get("optimized_candidates").and_then(|v| v.as_array())
220                    {
221                        results.optimized_candidates = cands.clone();
222                    }
223                    if let Some(cands) = data.get("attempted_candidates").and_then(|v| v.as_array())
224                    {
225                        results.attempted_candidates = cands.clone();
226                    }
227                    if results.best_prompt.is_none() {
228                        results.best_prompt = data.get("best_prompt").cloned();
229                    }
230                    if results.best_score.is_none() {
231                        results.best_score = extract_reward_value(data, &["best_score"]);
232                    }
233
234                    if let Some(validation) = data.get("validation").and_then(|v| v.as_array()) {
235                        for val in validation {
236                            if let Some(val_obj) = val.as_object() {
237                                if let (Some(rank), Some(score)) = (
238                                    val_obj.get("rank").and_then(|v| v.as_i64()),
239                                    extract_reward_value(val_obj, &[]),
240                                ) {
241                                    validation_by_rank.insert(rank, score);
242                                }
243                            }
244                        }
245                    }
246                }
247                "learning.policy.gepa.validation.completed" => {
248                    results.validation_results.push(Value::Object(data.clone()));
249                    if let (Some(rank), Some(score)) = (
250                        data.get("rank").and_then(|v| v.as_i64()),
251                        extract_reward_value(data, &[]),
252                    ) {
253                        validation_by_rank.insert(rank, score);
254                    }
255                }
256                "learning.policy.mipro.job.completed" => {
257                    if results.best_score.is_none() {
258                        results.best_score = extract_reward_value(
259                            data,
260                            &["best_score", "best_full_score", "best_minibatch_score"],
261                        );
262                    }
263                }
264                _ => {}
265            }
266        }
267
268        if results.top_prompts.is_empty() && !results.optimized_candidates.is_empty() {
269            for (idx, cand) in results.optimized_candidates.iter().enumerate() {
270                let cand_obj = match cand.as_object() {
271                    Some(obj) => obj,
272                    None => continue,
273                };
274                let rank = cand_obj
275                    .get("rank")
276                    .and_then(|v| v.as_i64())
277                    .unwrap_or((idx + 1) as i64);
278                let mut prompt_entry = Map::new();
279                prompt_entry.insert("rank".to_string(), json!(rank));
280
281                let train_accuracy = cand_obj
282                    .get("score")
283                    .and_then(|v| v.as_object())
284                    .and_then(|v| extract_reward_value(v, &[]))
285                    .or_else(|| extract_reward_value(cand_obj, &[]));
286                if let Some(score) = train_accuracy {
287                    prompt_entry.insert("train_accuracy".to_string(), json!(score));
288                }
289                if let Some(val) = validation_by_rank.get(&rank) {
290                    prompt_entry.insert("val_accuracy".to_string(), json!(*val));
291                }
292
293                if let Some(pattern) = cand_obj.get("pattern") {
294                    prompt_entry.insert("pattern".to_string(), pattern.clone());
295                    if let Some(text) = extract_full_text_from_pattern(pattern) {
296                        prompt_entry.insert("full_text".to_string(), json!(text));
297                    }
298                } else if let Some(template) = cand_obj.get("template") {
299                    if let Some(pattern) = convert_template_to_pattern(template) {
300                        if let Some(text) = extract_full_text_from_pattern(&pattern) {
301                            prompt_entry.insert("full_text".to_string(), json!(text));
302                        }
303                        prompt_entry.insert("pattern".to_string(), pattern);
304                    }
305                }
306
307                results.top_prompts.push(Value::Object(prompt_entry));
308            }
309        }
310
311        results
312    }
313}
314
315fn parse_events(value: Value) -> Result<Vec<Value>> {
316    if let Value::Array(items) = value {
317        return Ok(items);
318    }
319    if let Value::Object(obj) = value {
320        if let Some(Value::Array(items)) = obj.get("events") {
321            return Ok(items.clone());
322        }
323    }
324    Err(SynthError::UnexpectedResponse(
325        "events response did not contain an events list".to_string(),
326    ))
327}
328
329fn coerce_f64(value: &Value) -> Option<f64> {
330    match value {
331        Value::Number(num) => num.as_f64(),
332        Value::String(s) => s.parse::<f64>().ok(),
333        _ => None,
334    }
335}
336
337fn extract_outcome_reward(payload: &Map<String, Value>) -> Option<f64> {
338    if let Some(Value::Object(obj)) = payload.get("outcome_objectives") {
339        if let Some(val) = obj.get("reward").and_then(coerce_f64) {
340            return Some(val);
341        }
342    }
343    payload.get("outcome_reward").and_then(coerce_f64)
344}
345
346fn extract_reward_value(payload: &Map<String, Value>, fallback_keys: &[&str]) -> Option<f64> {
347    if let Some(val) = extract_outcome_reward(payload) {
348        return Some(val);
349    }
350    for key in fallback_keys {
351        if let Some(val) = payload.get(*key).and_then(coerce_f64) {
352            return Some(val);
353        }
354    }
355    None
356}
357
358fn convert_template_to_pattern(template: &Value) -> Option<Value> {
359    let sections = template
360        .get("sections")
361        .and_then(|v| v.as_array())
362        .filter(|v| !v.is_empty())
363        .or_else(|| template.get("prompt_sections").and_then(|v| v.as_array()))?;
364    let mut messages = Vec::new();
365    for sec in sections {
366        let sec_obj = sec.as_object()?;
367        let content = sec_obj.get("content")?;
368        if content.is_null() {
369            continue;
370        }
371        let role = sec_obj
372            .get("role")
373            .and_then(|v| v.as_str())
374            .or_else(|| sec_obj.get("name").and_then(|v| v.as_str()))
375            .unwrap_or("system");
376        let name = sec_obj.get("name").and_then(|v| v.as_str()).unwrap_or("");
377        messages.push(json!({
378            "role": role,
379            "name": name,
380            "pattern": content,
381        }));
382    }
383    if messages.is_empty() {
384        return None;
385    }
386    Some(json!({ "messages": messages }))
387}
388
389fn extract_full_text_from_pattern(pattern: &Value) -> Option<String> {
390    let messages = pattern.get("messages")?.as_array()?;
391    let mut parts = Vec::new();
392    for msg in messages {
393        let msg_obj = msg.as_object()?;
394        let role = msg_obj.get("role").and_then(|v| v.as_str()).unwrap_or("");
395        let name = msg_obj.get("name").and_then(|v| v.as_str()).unwrap_or("");
396        let content = msg_obj
397            .get("pattern")
398            .or_else(|| msg_obj.get("content"))
399            .and_then(|v| v.as_str())
400            .unwrap_or("");
401        parts.push(format!("[{role} | {name}]\n{content}"));
402    }
403    if parts.is_empty() {
404        None
405    } else {
406        Some(parts.join("\n\n"))
407    }
408}