1use once_cell::sync::OnceCell;
2use parking_lot::Mutex;
3use reqwest::blocking::Client as Http;
4use reqwest::StatusCode;
5use serde::Serialize;
6use std::env;
7use std::time::Duration;
8
9#[derive(Debug)]
11pub struct Client {
12 base_url: String,
13 project: String,
14 run: String,
15 write_token: Option<String>,
16
17 http: Http,
18 cached_bulk_path: OnceCell<String>,
19
20 buf: Mutex<Vec<LogItem>>,
22 max_batch: usize,
23 #[allow(dead_code)]
24 flush_interval: Duration,
25}
26
27#[derive(Debug, Clone, Serialize)]
28struct BulkPayload<'a> {
29 project: &'a str,
30 run: &'a str,
31 #[serde(rename = "metrics_list")]
32 metrics_list: Vec<serde_json::Value>,
33 steps: Vec<i64>,
34 timestamps: Vec<String>,
35 #[serde(skip_serializing_if = "Option::is_none")]
36 config: Option<serde_json::Value>,
37}
38
39#[derive(Debug, Clone)]
40pub struct LogItem {
41 pub metrics: serde_json::Value,
42 pub step: Option<i64>,
43 pub timestamp: Option<String>,
44}
45
46impl Client {
47 pub fn new() -> Self {
58 let base = env::var("TRACKIO_SERVER_URL").unwrap_or_else(|_| "http://127.0.0.1:7860".into());
59 let project = env::var("TRACKIO_PROJECT").unwrap_or_default();
60 let run = env::var("TRACKIO_RUN").unwrap_or_default();
61 let write_token = env::var("TRACKIO_WRITE_TOKEN").ok();
62
63 let timeout_ms = env::var("TRACKIO_TIMEOUT_MS")
64 .ok()
65 .and_then(|s| s.parse::<u64>().ok())
66 .unwrap_or(5000);
67
68 let max_batch = env::var("TRACKIO_MAX_BATCH")
69 .ok()
70 .and_then(|s| s.parse::<usize>().ok())
71 .unwrap_or(128);
72
73 let flush_interval = env::var("TRACKIO_FLUSH_INTERVAL_MS")
74 .ok()
75 .and_then(|s| s.parse::<u64>().ok())
76 .map(Duration::from_millis)
77 .unwrap_or(Duration::from_millis(200));
78
79 Self {
80 base_url: base,
81 project,
82 run,
83 write_token,
84 http: Http::builder()
85 .timeout(Duration::from_millis(timeout_ms))
86 .build()
87 .expect("failed to build HTTP client"),
88 cached_bulk_path: OnceCell::new(),
89 buf: Mutex::new(Vec::with_capacity(max_batch)),
90 max_batch,
91 flush_interval,
92 }
93 }
94
95 pub fn with_project(mut self, p: &str) -> Self {
96 self.project = p.into();
97 self
98 }
99
100 pub fn with_run(mut self, r: &str) -> Self {
101 self.run = r.into();
102 self
103 }
104
105 pub fn with_base_url(mut self, u: &str) -> Self {
106 self.base_url = u.into();
107 self
108 }
109
110 pub fn log(&self, metrics: serde_json::Value, step: Option<i64>, ts: Option<String>) {
113 let mut buf = self.buf.lock();
114 buf.push(LogItem {
115 metrics,
116 step,
117 timestamp: ts,
118 });
119 if buf.len() >= self.max_batch {
120 drop(buf);
121 let _ = self.flush(); }
123 }
124
125 pub fn flush(&self) -> Result<(), TrackioError> {
127 let items = {
128 let mut buf = self.buf.lock();
129 if buf.is_empty() {
130 return Ok(());
131 }
132 let out = buf.clone();
133 buf.clear();
134 out
135 };
136
137 let mut metrics_list = Vec::with_capacity(items.len());
138 let mut steps = Vec::with_capacity(items.len());
139 let mut timestamps = Vec::with_capacity(items.len());
140
141 for it in items {
142 metrics_list.push(it.metrics);
143 steps.push(it.step.unwrap_or(-1));
144 timestamps.push(it.timestamp.unwrap_or_else(|| "".into()));
145 }
146
147 let payload = BulkPayload {
148 project: &self.project,
149 run: &self.run,
150 metrics_list,
151 steps,
152 timestamps,
153 config: None,
154 };
155
156 let path = self.cached_bulk_path.get_or_try_init(|| {
158 if self.try_post("/api/bulk_log", &payload).is_ok() {
159 return Ok("/api/bulk_log".to_string());
160 }
161 if self.try_post("/gradio_api/bulk_log", &payload).is_ok() {
162 return Ok("/gradio_api/bulk_log".to_string());
163 }
164 Err(TrackioError::NoBulkEndpoint)
165 })?;
166
167 self.try_post(path, &payload)
168 }
169
170 fn try_post<P: AsRef<str>, T: Serialize>(
172 &self,
173 path: P,
174 payload: &T,
175 ) -> Result<(), TrackioError> {
176 let url = format!("{}{}", self.base_url, path.as_ref());
177 let mut req = self.http.post(url).json(payload);
178 if let Some(tok) = &self.write_token {
179 req = req.header("X-Trackio-Write-Token", tok);
180 }
181 let resp = req.send().map_err(TrackioError::Http)?;
182 if !resp.status().is_success() {
183 let status = resp.status();
184 let body = resp.text().unwrap_or_default();
185 if status == StatusCode::NOT_FOUND {
186 return Err(TrackioError::NotFound(body));
187 }
188 return Err(TrackioError::Status(status.as_u16(), body));
189 }
190 Ok(())
191 }
192
193 pub fn close(&self) -> Result<(), TrackioError> {
195 self.flush()
196 }
197}
198
199#[derive(thiserror::Error, Debug)]
200pub enum TrackioError {
201 #[error("no Trackio bulk endpoint found")]
202 NoBulkEndpoint,
203 #[error("HTTP error: {0}")]
204 Http(#[from] reqwest::Error),
205 #[error("404 Not Found: {0}")]
206 NotFound(String),
207 #[error("HTTP {0}: {1}")]
208 Status(u16, String),
209}