1use ratatui::{
10 layout::{Constraint, Direction, Layout, Rect},
11 style::Style,
12 symbols,
13 text::{Line, Span},
14 widgets::{Axis, Block, Borders, Chart, Dataset, GraphType, Paragraph},
15 Frame,
16};
17
18use crate::theme::{AxonmlTheme, INFO, TEAL, TERRACOTTA};
19
/// One chart sample as an `(x, y)` pair of `f64` values.
///
/// The charts in this file plot the first component on the "Epoch" axis and
/// the second component on the metric axis.
pub type DataPoint = (f64, f64);
26
/// A named, styled series of points.
///
/// NOTE(review): nothing in the visible part of this file constructs or reads
/// this type, which is presumably why `dead_code` is silenced below — confirm
/// whether it is still needed.
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub struct DataSeries {
    /// Display name of the series.
    pub name: String,
    /// The samples that make up the series.
    pub data: Vec<DataPoint>,
    /// Colour associated with the series.
    pub color: ratatui::style::Color,
    /// Marker symbol associated with the series.
    pub marker: symbols::Marker,
}
36
/// The charts the graphs view can display, one at a time.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ChartType {
    /// Training and validation loss curves.
    Loss,
    /// Training and validation accuracy curves.
    Accuracy,
    /// Learning-rate schedule.
    LearningRate,
}

impl ChartType {
    /// Returns the human-readable label used for this chart in the selector.
    fn as_str(&self) -> &'static str {
        // Exhaustive on purpose: adding a variant must force a label decision.
        match *self {
            Self::Loss => "Loss",
            Self::Accuracy => "Accuracy",
            Self::LearningRate => "Learning Rate",
        }
    }
}
54
/// State for the training-graphs screen: the metric series, the currently
/// selected chart and the axis bounds used when drawing.
pub struct GraphsView {
    /// Training-loss samples, drawn as "Train Loss" on the loss chart.
    pub train_loss: Vec<DataPoint>,

    /// Validation-loss samples, drawn as "Val Loss" on the loss chart.
    pub val_loss: Vec<DataPoint>,

    /// Training-accuracy samples, drawn as "Train Acc" on the accuracy chart.
    pub train_acc: Vec<DataPoint>,

    /// Validation-accuracy samples, drawn as "Val Acc" on the accuracy chart.
    pub val_acc: Vec<DataPoint>,

    /// Learning-rate samples, drawn on the learning-rate chart.
    pub learning_rate: Vec<DataPoint>,

    /// Which of the charts is currently shown.
    pub active_chart: ChartType,

    /// `[min, max]` bounds of the x ("Epoch") axis, shared by all charts.
    pub x_bounds: [f64; 2],

    /// `[min, max]` bounds of the y axis on the loss chart.
    pub loss_bounds: [f64; 2],

    /// `[min, max]` bounds of the y axis on the accuracy chart.
    pub acc_bounds: [f64; 2],
}
88
89impl GraphsView {
90 pub fn new() -> Self {
92 let mut view = Self {
93 train_loss: Vec::new(),
94 val_loss: Vec::new(),
95 train_acc: Vec::new(),
96 val_acc: Vec::new(),
97 learning_rate: Vec::new(),
98 active_chart: ChartType::Loss,
99 x_bounds: [0.0, 20.0],
100 loss_bounds: [0.0, 2.5],
101 acc_bounds: [0.0, 100.0],
102 };
103
104 view.load_demo_data();
105 view
106 }
107
108 pub fn load_demo_data(&mut self) {
110 self.train_loss = vec![
112 (1.0, 2.31), (2.0, 1.85), (3.0, 1.23), (4.0, 0.86), (5.0, 0.61),
113 (6.0, 0.48), (7.0, 0.39), (8.0, 0.32), (9.0, 0.27), (10.0, 0.23),
114 (11.0, 0.20), (12.0, 0.17), (13.0, 0.15), (14.0, 0.13), (15.0, 0.12),
115 ];
116
117 self.val_loss = vec![
118 (1.0, 2.30), (2.0, 1.76), (3.0, 1.19), (4.0, 0.82), (5.0, 0.60),
119 (6.0, 0.49), (7.0, 0.41), (8.0, 0.36), (9.0, 0.32), (10.0, 0.29),
120 (11.0, 0.27), (12.0, 0.25), (13.0, 0.24), (14.0, 0.23), (15.0, 0.22),
121 ];
122
123 self.train_acc = vec![
124 (1.0, 11.2), (2.0, 34.2), (3.0, 56.7), (4.0, 71.2), (5.0, 79.8),
125 (6.0, 84.5), (7.0, 87.8), (8.0, 90.2), (9.0, 91.8), (10.0, 93.1),
126 (11.0, 94.0), (12.0, 94.7), (13.0, 95.2), (14.0, 95.6), (15.0, 95.9),
127 ];
128
129 self.val_acc = vec![
130 (1.0, 11.8), (2.0, 35.8), (3.0, 58.2), (4.0, 72.4), (5.0, 80.5),
131 (6.0, 84.2), (7.0, 86.9), (8.0, 89.1), (9.0, 90.5), (10.0, 91.4),
132 (11.0, 92.1), (12.0, 92.6), (13.0, 93.0), (14.0, 93.3), (15.0, 93.5),
133 ];
134
135 self.learning_rate = vec![
136 (1.0, 1.0), (2.0, 1.0), (3.0, 1.0), (4.0, 1.0), (5.0, 0.5),
137 (6.0, 0.5), (7.0, 0.5), (8.0, 0.25), (9.0, 0.25), (10.0, 0.25),
138 (11.0, 0.125), (12.0, 0.125), (13.0, 0.125), (14.0, 0.0625), (15.0, 0.0625),
139 ];
140
141 self.x_bounds = [0.0, 16.0];
142 }
143
144 pub fn next_chart(&mut self) {
146 self.active_chart = match self.active_chart {
147 ChartType::Loss => ChartType::Accuracy,
148 ChartType::Accuracy => ChartType::LearningRate,
149 ChartType::LearningRate => ChartType::Loss,
150 };
151 }
152
153 pub fn prev_chart(&mut self) {
155 self.active_chart = match self.active_chart {
156 ChartType::Loss => ChartType::LearningRate,
157 ChartType::Accuracy => ChartType::Loss,
158 ChartType::LearningRate => ChartType::Accuracy,
159 };
160 }
161
162 pub fn toggle_zoom(&mut self) {
164 }
166
167 pub fn render(&mut self, frame: &mut Frame, area: Rect) {
169 let chunks = Layout::default()
170 .direction(Direction::Vertical)
171 .constraints([
172 Constraint::Length(3), Constraint::Min(15), Constraint::Length(5), ])
176 .split(area);
177
178 self.render_selector(frame, chunks[0]);
179 self.render_chart(frame, chunks[1]);
180 self.render_legend(frame, chunks[2]);
181 }
182
183 fn render_selector(&self, frame: &mut Frame, area: Rect) {
184 let tabs: Vec<Span> = [ChartType::Loss, ChartType::Accuracy, ChartType::LearningRate]
185 .iter()
186 .map(|ct| {
187 let style = if *ct == self.active_chart {
188 AxonmlTheme::tab_active()
189 } else {
190 AxonmlTheme::tab_inactive()
191 };
192 Span::styled(format!(" {} ", ct.as_str()), style)
193 })
194 .collect();
195
196 let selector = Paragraph::new(Line::from(tabs))
197 .block(
198 Block::default()
199 .borders(Borders::ALL)
200 .border_style(AxonmlTheme::border())
201 .title(Span::styled(" Chart Type (</> to switch) ", AxonmlTheme::header())),
202 );
203
204 frame.render_widget(selector, area);
205 }
206
207 fn render_chart(&self, frame: &mut Frame, area: Rect) {
208 match self.active_chart {
209 ChartType::Loss => self.render_loss_chart(frame, area),
210 ChartType::Accuracy => self.render_accuracy_chart(frame, area),
211 ChartType::LearningRate => self.render_lr_chart(frame, area),
212 }
213 }
214
215 fn render_loss_chart(&self, frame: &mut Frame, area: Rect) {
216 let datasets = vec![
217 Dataset::default()
218 .name("Train Loss")
219 .marker(symbols::Marker::Braille)
220 .graph_type(GraphType::Line)
221 .style(Style::default().fg(TEAL))
222 .data(&self.train_loss),
223 Dataset::default()
224 .name("Val Loss")
225 .marker(symbols::Marker::Braille)
226 .graph_type(GraphType::Line)
227 .style(Style::default().fg(TERRACOTTA))
228 .data(&self.val_loss),
229 ];
230
231 let chart = Chart::new(datasets)
232 .block(
233 Block::default()
234 .borders(Borders::ALL)
235 .border_style(AxonmlTheme::border_focused())
236 .title(Span::styled(" Loss Curves ", AxonmlTheme::header())),
237 )
238 .x_axis(
239 Axis::default()
240 .title(Span::styled("Epoch", AxonmlTheme::graph_label()))
241 .style(AxonmlTheme::graph_axis())
242 .bounds(self.x_bounds)
243 .labels(vec![
244 Span::raw("0"),
245 Span::raw("5"),
246 Span::raw("10"),
247 Span::raw("15"),
248 ]),
249 )
250 .y_axis(
251 Axis::default()
252 .title(Span::styled("Loss", AxonmlTheme::graph_label()))
253 .style(AxonmlTheme::graph_axis())
254 .bounds(self.loss_bounds)
255 .labels(vec![
256 Span::raw("0.0"),
257 Span::raw("1.0"),
258 Span::raw("2.0"),
259 ]),
260 );
261
262 frame.render_widget(chart, area);
263 }
264
265 fn render_accuracy_chart(&self, frame: &mut Frame, area: Rect) {
266 let datasets = vec![
267 Dataset::default()
268 .name("Train Acc")
269 .marker(symbols::Marker::Braille)
270 .graph_type(GraphType::Line)
271 .style(Style::default().fg(TEAL))
272 .data(&self.train_acc),
273 Dataset::default()
274 .name("Val Acc")
275 .marker(symbols::Marker::Braille)
276 .graph_type(GraphType::Line)
277 .style(Style::default().fg(TERRACOTTA))
278 .data(&self.val_acc),
279 ];
280
281 let chart = Chart::new(datasets)
282 .block(
283 Block::default()
284 .borders(Borders::ALL)
285 .border_style(AxonmlTheme::border_focused())
286 .title(Span::styled(" Accuracy Curves ", AxonmlTheme::header())),
287 )
288 .x_axis(
289 Axis::default()
290 .title(Span::styled("Epoch", AxonmlTheme::graph_label()))
291 .style(AxonmlTheme::graph_axis())
292 .bounds(self.x_bounds)
293 .labels(vec![
294 Span::raw("0"),
295 Span::raw("5"),
296 Span::raw("10"),
297 Span::raw("15"),
298 ]),
299 )
300 .y_axis(
301 Axis::default()
302 .title(Span::styled("Accuracy %", AxonmlTheme::graph_label()))
303 .style(AxonmlTheme::graph_axis())
304 .bounds(self.acc_bounds)
305 .labels(vec![
306 Span::raw("0"),
307 Span::raw("50"),
308 Span::raw("100"),
309 ]),
310 );
311
312 frame.render_widget(chart, area);
313 }
314
315 fn render_lr_chart(&self, frame: &mut Frame, area: Rect) {
316 let datasets = vec![
317 Dataset::default()
318 .name("Learning Rate")
319 .marker(symbols::Marker::Braille)
320 .graph_type(GraphType::Line)
321 .style(Style::default().fg(INFO))
322 .data(&self.learning_rate),
323 ];
324
325 let chart = Chart::new(datasets)
326 .block(
327 Block::default()
328 .borders(Borders::ALL)
329 .border_style(AxonmlTheme::border_focused())
330 .title(Span::styled(" Learning Rate Schedule ", AxonmlTheme::header())),
331 )
332 .x_axis(
333 Axis::default()
334 .title(Span::styled("Epoch", AxonmlTheme::graph_label()))
335 .style(AxonmlTheme::graph_axis())
336 .bounds(self.x_bounds)
337 .labels(vec![
338 Span::raw("0"),
339 Span::raw("5"),
340 Span::raw("10"),
341 Span::raw("15"),
342 ]),
343 )
344 .y_axis(
345 Axis::default()
346 .title(Span::styled("LR (relative)", AxonmlTheme::graph_label()))
347 .style(AxonmlTheme::graph_axis())
348 .bounds([0.0, 1.2])
349 .labels(vec![
350 Span::raw("0"),
351 Span::raw("0.5"),
352 Span::raw("1.0"),
353 ]),
354 );
355
356 frame.render_widget(chart, area);
357 }
358
359 fn render_legend(&self, frame: &mut Frame, area: Rect) {
360 let legend_text = match self.active_chart {
361 ChartType::Loss => vec![
362 Line::from(vec![
363 Span::styled("\u{2588}\u{2588}", Style::default().fg(TEAL)),
364 Span::styled(" Train Loss", AxonmlTheme::graph_label()),
365 Span::raw(" "),
366 Span::styled("\u{2588}\u{2588}", Style::default().fg(TERRACOTTA)),
367 Span::styled(" Val Loss", AxonmlTheme::graph_label()),
368 ]),
369 Line::from(vec![
370 Span::styled("Latest: ", AxonmlTheme::muted()),
371 Span::styled(
372 format!("Train {:.4}", self.train_loss.last().map(|p| p.1).unwrap_or(0.0)),
373 AxonmlTheme::metric_value(),
374 ),
375 Span::raw(" "),
376 Span::styled(
377 format!("Val {:.4}", self.val_loss.last().map(|p| p.1).unwrap_or(0.0)),
378 AxonmlTheme::accent(),
379 ),
380 ]),
381 ],
382 ChartType::Accuracy => vec![
383 Line::from(vec![
384 Span::styled("\u{2588}\u{2588}", Style::default().fg(TEAL)),
385 Span::styled(" Train Acc", AxonmlTheme::graph_label()),
386 Span::raw(" "),
387 Span::styled("\u{2588}\u{2588}", Style::default().fg(TERRACOTTA)),
388 Span::styled(" Val Acc", AxonmlTheme::graph_label()),
389 ]),
390 Line::from(vec![
391 Span::styled("Latest: ", AxonmlTheme::muted()),
392 Span::styled(
393 format!("Train {:.1}%", self.train_acc.last().map(|p| p.1).unwrap_or(0.0)),
394 AxonmlTheme::success(),
395 ),
396 Span::raw(" "),
397 Span::styled(
398 format!("Val {:.1}%", self.val_acc.last().map(|p| p.1).unwrap_or(0.0)),
399 AxonmlTheme::success(),
400 ),
401 ]),
402 ],
403 ChartType::LearningRate => vec![
404 Line::from(vec![
405 Span::styled("\u{2588}\u{2588}", Style::default().fg(INFO)),
406 Span::styled(" Learning Rate (normalized)", AxonmlTheme::graph_label()),
407 ]),
408 Line::from(vec![
409 Span::styled("Current: ", AxonmlTheme::muted()),
410 Span::styled(
411 format!("{:.4}x initial", self.learning_rate.last().map(|p| p.1).unwrap_or(1.0)),
412 AxonmlTheme::metric_value(),
413 ),
414 ]),
415 ],
416 };
417
418 let legend = Paragraph::new(legend_text)
419 .block(
420 Block::default()
421 .borders(Borders::ALL)
422 .border_style(AxonmlTheme::border())
423 .title(Span::styled(" Legend ", AxonmlTheme::header())),
424 );
425
426 frame.render_widget(legend, area);
427 }
428}
429
430impl Default for GraphsView {
431 fn default() -> Self {
432 Self::new()
433 }
434}