Skip to main content

synth_ai/
localapi.rs

1use std::collections::HashSet;
2use std::future::Future;
3use std::net::SocketAddr;
4use std::sync::Arc;
5
6use axum::extract::State;
7use axum::http::{HeaderMap, StatusCode};
8use axum::response::{IntoResponse, Response};
9use axum::routing::{get, post};
10use axum::{Json, Router};
11use futures_util::future::BoxFuture;
12use serde::{Deserialize, Serialize};
13use serde_json::{json, Map, Value};
14
15use crate::types::{Result, SynthError};
16
17pub type RolloutHandler =
18    Arc<dyn Fn(RolloutRequest) -> BoxFuture<'static, std::result::Result<RolloutResponse, LocalApiError>>
19        + Send
20        + Sync>;
21
22#[derive(Debug, Clone)]
23pub struct LocalApiError {
24    pub status: StatusCode,
25    pub message: String,
26}
27
28impl LocalApiError {
29    pub fn bad_request(message: impl Into<String>) -> Self {
30        Self {
31            status: StatusCode::BAD_REQUEST,
32            message: message.into(),
33        }
34    }
35
36    pub fn internal(message: impl Into<String>) -> Self {
37        Self {
38            status: StatusCode::INTERNAL_SERVER_ERROR,
39            message: message.into(),
40        }
41    }
42}
43
44#[derive(Debug, Clone, Serialize, Deserialize)]
45pub struct TaskDescriptor {
46    pub id: String,
47    pub name: String,
48    #[serde(skip_serializing_if = "Option::is_none")]
49    pub description: Option<String>,
50    #[serde(skip_serializing_if = "Option::is_none")]
51    pub version: Option<String>,
52    #[serde(flatten)]
53    pub extra: Map<String, Value>,
54}
55
56#[derive(Debug, Clone, Serialize, Deserialize, Default)]
57pub struct DatasetInfo {
58    #[serde(skip_serializing_if = "Option::is_none")]
59    pub id: Option<String>,
60    #[serde(skip_serializing_if = "Option::is_none")]
61    pub name: Option<String>,
62    #[serde(skip_serializing_if = "Option::is_none")]
63    pub version: Option<String>,
64    #[serde(skip_serializing_if = "Option::is_none")]
65    pub splits: Option<Vec<String>>,
66    #[serde(skip_serializing_if = "Option::is_none")]
67    pub default_split: Option<String>,
68    #[serde(skip_serializing_if = "Option::is_none")]
69    pub description: Option<String>,
70    #[serde(flatten)]
71    pub extra: Map<String, Value>,
72}
73
74#[derive(Debug, Clone, Serialize, Deserialize, Default)]
75pub struct InferenceInfo {
76    #[serde(skip_serializing_if = "Option::is_none")]
77    pub model: Option<String>,
78    #[serde(skip_serializing_if = "Option::is_none")]
79    pub inference_url: Option<String>,
80    #[serde(flatten)]
81    pub extra: Map<String, Value>,
82}
83
84#[derive(Debug, Clone, Serialize, Deserialize, Default)]
85pub struct LimitsInfo {
86    #[serde(skip_serializing_if = "Option::is_none")]
87    pub max_turns: Option<i64>,
88    #[serde(skip_serializing_if = "Option::is_none")]
89    pub max_response_tokens: Option<i64>,
90    #[serde(skip_serializing_if = "Option::is_none")]
91    pub timeout_seconds: Option<i64>,
92    #[serde(flatten)]
93    pub extra: Map<String, Value>,
94}
95
96#[derive(Debug, Clone, Serialize, Deserialize)]
97pub struct TaskInfo {
98    pub task: TaskDescriptor,
99    pub dataset: DatasetInfo,
100    pub inference: InferenceInfo,
101    pub limits: LimitsInfo,
102    #[serde(skip_serializing_if = "Option::is_none")]
103    pub task_metadata: Option<Value>,
104    #[serde(flatten)]
105    pub extra: Map<String, Value>,
106}
107
108impl TaskInfo {
109    pub fn minimal(app_id: impl Into<String>, name: impl Into<String>, description: impl Into<String>) -> Self {
110        let task = TaskDescriptor {
111            id: app_id.into(),
112            name: name.into(),
113            description: Some(description.into()),
114            version: None,
115            extra: Map::new(),
116        };
117        Self {
118            task,
119            dataset: DatasetInfo::default(),
120            inference: InferenceInfo::default(),
121            limits: LimitsInfo::default(),
122            task_metadata: None,
123            extra: Map::new(),
124        }
125    }
126}
127
128#[derive(Debug, Clone, Serialize, Deserialize)]
129pub struct RolloutRequest {
130    pub trace_correlation_id: String,
131    pub env: Value,
132    pub policy: Value,
133    #[serde(skip_serializing_if = "Option::is_none")]
134    pub on_done: Option<String>,
135    #[serde(skip_serializing_if = "Option::is_none")]
136    pub safety: Option<Value>,
137    #[serde(skip_serializing_if = "Option::is_none")]
138    pub training_session_id: Option<String>,
139    #[serde(skip_serializing_if = "Option::is_none")]
140    pub synth_base_url: Option<String>,
141    #[serde(skip_serializing_if = "Option::is_none")]
142    pub context_overrides: Option<Value>,
143    #[serde(skip_serializing_if = "Option::is_none")]
144    pub override_bundle_id: Option<String>,
145    #[serde(flatten)]
146    pub extra: Map<String, Value>,
147}
148
149#[derive(Debug, Clone, Serialize, Deserialize)]
150pub struct RolloutMetrics {
151    pub outcome_reward: f64,
152    #[serde(skip_serializing_if = "Option::is_none")]
153    pub event_rewards: Option<Vec<f64>>,
154    #[serde(skip_serializing_if = "Option::is_none")]
155    pub outcome_objectives: Option<Map<String, Value>>,
156    #[serde(skip_serializing_if = "Option::is_none")]
157    pub event_objectives: Option<Vec<Map<String, Value>>>,
158    #[serde(skip_serializing_if = "Option::is_none")]
159    pub instance_objectives: Option<Vec<Map<String, Value>>>,
160    #[serde(default, skip_serializing_if = "Map::is_empty")]
161    pub details: Map<String, Value>,
162}
163
164#[derive(Debug, Clone, Serialize, Deserialize)]
165pub struct RolloutResponse {
166    pub trace_correlation_id: String,
167    pub reward_info: RolloutMetrics,
168    #[serde(skip_serializing_if = "Option::is_none")]
169    pub trace: Option<Value>,
170    #[serde(skip_serializing_if = "Option::is_none")]
171    pub inference_url: Option<String>,
172    #[serde(skip_serializing_if = "Option::is_none")]
173    pub artifact: Option<Value>,
174    #[serde(skip_serializing_if = "Option::is_none")]
175    pub success_status: Option<String>,
176    #[serde(skip_serializing_if = "Option::is_none")]
177    pub status_detail: Option<String>,
178    #[serde(skip_serializing_if = "Option::is_none")]
179    pub override_application_results: Option<Value>,
180    #[serde(flatten)]
181    pub extra: Map<String, Value>,
182}
183
184#[derive(Clone)]
185pub struct LocalApiConfig {
186    pub task_info: TaskInfo,
187    pub rollout: RolloutHandler,
188    pub require_api_key: bool,
189    pub api_keys: Vec<String>,
190}
191
192impl LocalApiConfig {
193    pub fn new<F, Fut>(
194        app_id: impl Into<String>,
195        name: impl Into<String>,
196        description: impl Into<String>,
197        handler: F,
198    ) -> Self
199    where
200        F: Fn(RolloutRequest) -> Fut + Send + Sync + 'static,
201        Fut: Future<Output = std::result::Result<RolloutResponse, LocalApiError>> + Send + 'static,
202    {
203        let rollout: RolloutHandler = Arc::new(move |req| Box::pin(handler(req)));
204        let mut api_keys = Vec::new();
205        if let Ok(val) = std::env::var("ENVIRONMENT_API_KEY") {
206            api_keys.push(val);
207        }
208        Self {
209            task_info: TaskInfo::minimal(app_id, name, description),
210            rollout,
211            require_api_key: true,
212            api_keys,
213        }
214    }
215}
216
217#[derive(Clone)]
218pub struct LocalApiApp {
219    router: Router,
220}
221
222pub fn create_local_api(config: LocalApiConfig) -> LocalApiApp {
223    let state = Arc::new(config);
224
225    let router = Router::new()
226        .route("/", get(root))
227        .route("/health", get(health))
228        .route("/task_info", get(task_info))
229        .route("/info", get(info))
230        .route("/rollout", post(rollout))
231        .with_state(state);
232
233    LocalApiApp { router }
234}
235
236impl LocalApiApp {
237    pub fn router(&self) -> Router {
238        self.router.clone()
239    }
240
241    pub async fn run(self, addr: SocketAddr) -> Result<()> {
242        axum::Server::bind(&addr)
243            .serve(self.router.into_make_service())
244            .await
245            .map_err(|err| SynthError::UnexpectedResponse(err.to_string()))
246    }
247}
248
249async fn root() -> Response {
250    Json(json!({"status": "ok"})).into_response()
251}
252
253async fn health(
254    State(config): State<Arc<LocalApiConfig>>,
255    headers: HeaderMap,
256) -> Response {
257    if let Err(resp) = authorize(&config, &headers) {
258        return resp;
259    }
260    Json(json!({ "healthy": true })).into_response()
261}
262
263async fn task_info(
264    State(config): State<Arc<LocalApiConfig>>,
265    headers: HeaderMap,
266) -> Response {
267    if let Err(resp) = authorize(&config, &headers) {
268        return resp;
269    }
270    Json(config.task_info.clone()).into_response()
271}
272
273async fn info(
274    State(config): State<Arc<LocalApiConfig>>,
275    headers: HeaderMap,
276) -> Response {
277    if let Err(resp) = authorize(&config, &headers) {
278        return resp;
279    }
280    let task = config.task_info.task.clone();
281    let version = task.version.clone();
282    let service = json!({
283        "task": task,
284        "version": version,
285    });
286    let payload = json!({
287        "service": service,
288        "dataset": config.task_info.dataset,
289        "rubrics": null,
290        "inference": config.task_info.inference,
291        "limits": config.task_info.limits,
292    });
293    Json(payload).into_response()
294}
295
296async fn rollout(
297    State(config): State<Arc<LocalApiConfig>>,
298    headers: HeaderMap,
299    Json(request): Json<RolloutRequest>,
300) -> impl IntoResponse {
301    if let Err(resp) = authorize(&config, &headers) {
302        return resp;
303    }
304    let handler = config.rollout.clone();
305    match handler(request).await {
306        Ok(resp) => (StatusCode::OK, Json(resp)).into_response(),
307        Err(err) => (
308            err.status,
309            Json(json!({ "error": err.message })),
310        )
311            .into_response(),
312    }
313}
314
315fn authorize(config: &LocalApiConfig, headers: &HeaderMap) -> std::result::Result<(), axum::response::Response> {
316    if !config.require_api_key {
317        return Ok(());
318    }
319    let allowed = api_key_set(config);
320    if allowed.is_empty() {
321        let resp = (
322            StatusCode::SERVICE_UNAVAILABLE,
323            Json(json!({ "error": "ENVIRONMENT_API_KEY is not configured" })),
324        )
325            .into_response();
326        return Err(resp);
327    }
328    let provided = header_keys(headers);
329    if provided.iter().any(|key| allowed.contains(key)) {
330        return Ok(());
331    }
332    let resp = (
333        StatusCode::UNAUTHORIZED,
334        Json(json!({ "error": "API key missing or invalid" })),
335    )
336        .into_response();
337    Err(resp)
338}
339
340fn api_key_set(config: &LocalApiConfig) -> HashSet<String> {
341    let mut set = HashSet::new();
342    for key in &config.api_keys {
343        if !key.is_empty() {
344            set.insert(key.clone());
345        }
346    }
347    if let Ok(aliases) = std::env::var("ENVIRONMENT_API_KEY_ALIASES") {
348        for part in aliases.split(',') {
349            let trimmed = part.trim();
350            if !trimmed.is_empty() {
351                set.insert(trimmed.to_string());
352            }
353        }
354    }
355    set
356}
357
358fn header_keys(headers: &HeaderMap) -> Vec<String> {
359    let mut keys = Vec::new();
360    for header in ["x-api-key", "x-api-keys", "authorization"] {
361        if let Some(value) = headers.get(header) {
362            if let Ok(text) = value.to_str() {
363                if header == "authorization" && text.to_lowercase().starts_with("bearer ") {
364                    keys.extend(split_keys(&text[7..]));
365                } else {
366                    keys.extend(split_keys(text));
367                }
368            }
369        }
370    }
371    keys
372}
373
/// Splits a comma-separated key list, trimming whitespace and dropping empty entries.
fn split_keys(input: &str) -> Vec<String> {
    let mut parts = Vec::new();
    for piece in input.split(',') {
        let piece = piece.trim();
        if piece.is_empty() {
            continue;
        }
        parts.push(piece.to_owned());
    }
    parts
}