1use std::path::Path;
9
10use ratatui::{
11 layout::{Constraint, Direction, Layout, Rect},
12 style::Style,
13 text::{Line, Span},
14 widgets::{Block, Borders, List, ListItem, ListState, Paragraph, Scrollbar, ScrollbarOrientation, ScrollbarState},
15 Frame,
16};
17
18use crate::theme::AxonmlTheme;
19use axonml_serialize::load_state_dict;
20
/// Summary of one layer, built by grouping state-dict parameters that share a
/// name prefix.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LayerInfo {
    /// Layer name (a parameter name with its final `.`-separated component removed).
    pub name: String,
    /// Heuristically inferred layer kind, e.g. `"Linear"` or `"Conv2d"`.
    pub layer_type: String,
    /// Inferred input shape for display, or `"-"` when it cannot be inferred.
    pub input_shape: String,
    /// Inferred output shape for display, or `"-"` when it cannot be inferred.
    pub output_shape: String,
    /// Number of scalar parameters in the layer.
    pub params: usize,
    /// Whether any of the layer's parameters require gradients.
    pub trainable: bool,
}
35
36#[derive(Debug, Clone)]
38pub struct ModelInfo {
39 pub name: String,
40 pub layers: Vec<LayerInfo>,
41 pub total_params: usize,
42 pub trainable_params: usize,
43 pub file_size: u64,
44 pub format: String,
45}
46
47pub struct ModelView {
53 pub model: Option<ModelInfo>,
55
56 pub selected_layer: usize,
58
59 pub list_state: ListState,
61
62 pub scroll_state: ScrollbarState,
64
65 pub show_details: bool,
67}
68
69impl ModelView {
70 pub fn new() -> Self {
72 let mut list_state = ListState::default();
73 list_state.select(Some(0));
74
75 Self {
76 model: None,
77 selected_layer: 0,
78 list_state,
79 scroll_state: ScrollbarState::default(),
80 show_details: false,
81 }
82 }
83
84 pub fn load_model(&mut self, path: &Path) -> Result<(), String> {
86 let model = self.parse_model_file(path)?;
89 let layer_count = model.layers.len();
90 self.model = Some(model);
91 self.selected_layer = 0;
92 self.list_state.select(Some(0));
93 self.scroll_state = ScrollbarState::default().content_length(layer_count);
94 Ok(())
95 }
96
97 fn parse_model_file(&self, path: &Path) -> Result<ModelInfo, String> {
99 let file_name = path
100 .file_name()
101 .and_then(|n| n.to_str())
102 .unwrap_or("model");
103
104 let file_size = std::fs::metadata(path)
106 .map(|m| m.len())
107 .unwrap_or(0);
108
109 let format = path.extension()
111 .and_then(|ext| ext.to_str())
112 .map(|ext| match ext.to_lowercase().as_str() {
113 "axonml" => "Axonml",
114 "safetensors" => "SafeTensors",
115 "json" => "JSON",
116 "pt" | "pth" => "PyTorch",
117 "onnx" => "ONNX",
118 _ => "Unknown",
119 })
120 .unwrap_or("Unknown")
121 .to_string();
122
123 let state_dict = load_state_dict(path)
125 .map_err(|e| format!("Failed to load model: {}", e))?;
126
127 let mut layer_map: std::collections::BTreeMap<String, Vec<(String, Vec<usize>, usize, bool)>> =
129 std::collections::BTreeMap::new();
130
131 for (param_name, entry) in state_dict.entries() {
132 let shape = entry.data.shape.clone();
133 let num_params: usize = shape.iter().product();
134
135 let layer_name = if let Some(idx) = param_name.rfind('.') {
137 param_name[..idx].to_string()
138 } else {
139 param_name.clone()
140 };
141
142 layer_map
143 .entry(layer_name)
144 .or_default()
145 .push((param_name.clone(), shape, num_params, entry.requires_grad));
146 }
147
148 let mut layers = Vec::new();
150 for (layer_name, params) in layer_map {
151 let layer_num_params: usize = params.iter().map(|(_, _, p, _)| *p).sum();
152 let trainable = params.iter().any(|(_, _, _, t)| *t);
153
154 let layer_type = infer_layer_type(¶ms);
156
157 let (input_shape, output_shape) = infer_shapes(¶ms);
159
160 layers.push(LayerInfo {
161 name: layer_name,
162 layer_type,
163 input_shape,
164 output_shape,
165 params: layer_num_params,
166 trainable,
167 });
168 }
169
170 let total_params: usize = layers.iter().map(|l| l.params).sum();
171 let trainable_params: usize = layers
172 .iter()
173 .filter(|l| l.trainable)
174 .map(|l| l.params)
175 .sum();
176
177 Ok(ModelInfo {
178 name: file_name.to_string(),
179 layers,
180 total_params,
181 trainable_params,
182 file_size,
183 format,
184 })
185 }
186
187 pub fn select_prev(&mut self) {
189 if self.model.is_some() && self.selected_layer > 0 {
190 self.selected_layer -= 1;
191 self.list_state.select(Some(self.selected_layer));
192 self.scroll_state = self.scroll_state.position(self.selected_layer);
193 }
194 }
195
196 pub fn select_next(&mut self) {
198 if let Some(model) = &self.model {
199 if self.selected_layer < model.layers.len() - 1 {
200 self.selected_layer += 1;
201 self.list_state.select(Some(self.selected_layer));
202 self.scroll_state = self.scroll_state.position(self.selected_layer);
203 }
204 }
205 }
206
207 pub fn toggle_details(&mut self) {
209 self.show_details = !self.show_details;
210 }
211
212 pub fn render(&mut self, frame: &mut Frame, area: Rect) {
214 if let Some(model) = self.model.clone() {
215 let chunks = Layout::default()
216 .direction(Direction::Vertical)
217 .constraints([
218 Constraint::Length(5), Constraint::Min(10), Constraint::Length(8), ])
222 .split(area);
223
224 self.render_header(frame, chunks[0], &model);
225 self.render_layers(frame, chunks[1], &model);
226 self.render_details(frame, chunks[2], &model);
227 } else {
228 self.render_empty(frame, area);
229 }
230 }
231
232 fn render_header(&self, frame: &mut Frame, area: Rect, model: &ModelInfo) {
233 let header_text = vec![
234 Line::from(vec![
235 Span::styled("Model: ", AxonmlTheme::muted()),
236 Span::styled(&model.name, AxonmlTheme::title()),
237 ]),
238 Line::from(vec![
239 Span::styled("Format: ", AxonmlTheme::muted()),
240 Span::styled(&model.format, AxonmlTheme::accent()),
241 Span::raw(" "),
242 Span::styled("Size: ", AxonmlTheme::muted()),
243 Span::styled(format_size(model.file_size), AxonmlTheme::accent()),
244 ]),
245 Line::from(vec![
246 Span::styled("Total Params: ", AxonmlTheme::muted()),
247 Span::styled(format_number(model.total_params), AxonmlTheme::metric_value()),
248 Span::raw(" "),
249 Span::styled("Trainable: ", AxonmlTheme::muted()),
250 Span::styled(format_number(model.trainable_params), AxonmlTheme::success()),
251 ]),
252 ];
253
254 let header = Paragraph::new(header_text)
255 .block(
256 Block::default()
257 .borders(Borders::ALL)
258 .border_style(AxonmlTheme::border())
259 .title(Span::styled(" Model Info ", AxonmlTheme::header())),
260 );
261
262 frame.render_widget(header, area);
263 }
264
265 fn render_layers(&mut self, frame: &mut Frame, area: Rect, model: &ModelInfo) {
266 let items: Vec<ListItem> = model
267 .layers
268 .iter()
269 .enumerate()
270 .map(|(i, layer)| {
271 let style = if i == self.selected_layer {
272 AxonmlTheme::selected()
273 } else {
274 Style::default()
275 };
276
277 let content = Line::from(vec![
278 Span::styled(
279 format!("{:>2}. ", i + 1),
280 AxonmlTheme::muted(),
281 ),
282 Span::styled(
283 format!("{:<12}", layer.name),
284 AxonmlTheme::layer_type(),
285 ),
286 Span::styled(
287 format!("{:<15}", layer.layer_type),
288 AxonmlTheme::accent(),
289 ),
290 Span::styled(
291 format!("{:>15}", layer.output_shape),
292 AxonmlTheme::layer_shape(),
293 ),
294 Span::styled(
295 format!("{:>12}", format_number(layer.params)),
296 if layer.trainable {
297 AxonmlTheme::success()
298 } else {
299 AxonmlTheme::muted()
300 },
301 ),
302 ]);
303
304 ListItem::new(content).style(style)
305 })
306 .collect();
307
308 let list = List::new(items)
309 .block(
310 Block::default()
311 .borders(Borders::ALL)
312 .border_style(AxonmlTheme::border_focused())
313 .title(Span::styled(" Layers ", AxonmlTheme::header())),
314 )
315 .highlight_style(AxonmlTheme::selected());
316
317 frame.render_stateful_widget(list, area, &mut self.list_state);
318
319 let scrollbar = Scrollbar::new(ScrollbarOrientation::VerticalRight)
321 .begin_symbol(Some("▲"))
322 .end_symbol(Some("▼"));
323
324 frame.render_stateful_widget(
325 scrollbar,
326 area.inner(ratatui::layout::Margin { vertical: 1, horizontal: 0 }),
327 &mut self.scroll_state,
328 );
329 }
330
331 fn render_details(&self, frame: &mut Frame, area: Rect, model: &ModelInfo) {
332 let layer = &model.layers[self.selected_layer];
333
334 let details = vec![
335 Line::from(vec![
336 Span::styled("Layer: ", AxonmlTheme::muted()),
337 Span::styled(&layer.name, AxonmlTheme::title()),
338 Span::styled(" (", AxonmlTheme::muted()),
339 Span::styled(&layer.layer_type, AxonmlTheme::accent()),
340 Span::styled(")", AxonmlTheme::muted()),
341 ]),
342 Line::from(vec![
343 Span::styled("Input: ", AxonmlTheme::muted()),
344 Span::styled(&layer.input_shape, AxonmlTheme::layer_shape()),
345 ]),
346 Line::from(vec![
347 Span::styled("Output: ", AxonmlTheme::muted()),
348 Span::styled(&layer.output_shape, AxonmlTheme::layer_shape()),
349 ]),
350 Line::from(vec![
351 Span::styled("Params: ", AxonmlTheme::muted()),
352 Span::styled(format_number(layer.params), AxonmlTheme::metric_value()),
353 Span::raw(" "),
354 Span::styled("Trainable: ", AxonmlTheme::muted()),
355 Span::styled(
356 if layer.trainable { "Yes" } else { "No" },
357 if layer.trainable {
358 AxonmlTheme::success()
359 } else {
360 AxonmlTheme::muted()
361 },
362 ),
363 ]),
364 ];
365
366 let details_widget = Paragraph::new(details)
367 .block(
368 Block::default()
369 .borders(Borders::ALL)
370 .border_style(AxonmlTheme::border())
371 .title(Span::styled(" Layer Details ", AxonmlTheme::header())),
372 );
373
374 frame.render_widget(details_widget, area);
375 }
376
377 fn render_empty(&self, frame: &mut Frame, area: Rect) {
378 let text = vec![
379 Line::from(""),
380 Line::from(Span::styled(
381 "No model loaded",
382 AxonmlTheme::muted(),
383 )),
384 Line::from(""),
385 Line::from(Span::styled(
386 "Press 'o' to open a model file",
387 AxonmlTheme::info(),
388 )),
389 Line::from(Span::styled(
390 "or use: axonml tui --model <path>",
391 AxonmlTheme::muted(),
392 )),
393 ];
394
395 let paragraph = Paragraph::new(text)
396 .block(
397 Block::default()
398 .borders(Borders::ALL)
399 .border_style(AxonmlTheme::border())
400 .title(Span::styled(" Model Architecture ", AxonmlTheme::header())),
401 )
402 .alignment(ratatui::layout::Alignment::Center);
403
404 frame.render_widget(paragraph, area);
405 }
406}
407
408impl Default for ModelView {
409 fn default() -> Self {
410 Self::new()
411 }
412}
413
/// Formats a byte count with binary (1024-based) units and two decimals,
/// e.g. `1536` -> `"1.50 KB"`. Counts below 1024 are printed as `"N B"`.
///
/// The unit is promoted whenever the rounded value would read `1024.00` or
/// more, so `1 MiB - 1` prints as `"1.00 MB"` rather than `"1024.00 KB"`.
fn format_size(bytes: u64) -> String {
    const UNITS: [&str; 4] = ["KB", "MB", "GB", "TB"];
    const STEP: f64 = 1024.0;

    if bytes < 1024 {
        return format!("{} B", bytes);
    }

    let mut value = bytes as f64 / STEP;
    let mut unit = 0;
    // 1023.995 is the smallest value that `{:.2}` renders as "1024.00".
    while value >= 1023.995 && unit + 1 < UNITS.len() {
        value /= STEP;
        unit += 1;
    }

    format!("{:.2} {}", value, UNITS[unit])
}
433
/// Formats a count with a decimal suffix and two decimals: `K` (thousands),
/// `M` (millions) or `B` (billions), e.g. `1_500` -> `"1.50K"`. Counts below
/// 1000 are printed as plain integers.
///
/// When two-decimal rounding would reach `1000.00` (e.g. `999_999`), the next
/// suffix is used instead (`"1.00M"`).
fn format_number(n: usize) -> String {
    // Ordered smallest first so a rounded-up value can be promoted by index.
    const SCALES: [(f64, &str); 3] = [(1e3, "K"), (1e6, "M"), (1e9, "B")];

    let value = n as f64;
    let Some(mut idx) = SCALES.iter().rposition(|&(div, _)| value >= div) else {
        return n.to_string();
    };

    // Values >= 999.995 would render as "1000.00" at the current scale.
    if idx + 1 < SCALES.len() && value / SCALES[idx].0 >= 999.995 {
        idx += 1;
    }

    let (div, suffix) = SCALES[idx];
    format!("{:.2}{}", value / div, suffix)
}
443
/// Guesses a layer's kind from its parameters, returning the verdict of the
/// first parameter that matches a rule (or `"Unknown"` if none does).
///
/// Rules, checked per parameter in order:
/// - a `.weight` of rank 4 / 2 / 1 -> `Conv2d` / `Linear` / `BatchNorm`;
/// - a `.gamma` or `.beta` -> `LayerNorm`;
/// - an `.embedding` -> `Embedding`.
///
/// NOTE(review): purely name/rank based — e.g. a 1-D `.weight` from a
/// LayerNorm is also reported as `BatchNorm`.
fn infer_layer_type(params: &[(String, Vec<usize>, usize, bool)]) -> String {
    let classify = |name: &str, shape: &[usize]| -> Option<&'static str> {
        let by_weight_rank = if name.ends_with(".weight") {
            match shape.len() {
                4 => Some("Conv2d"),
                2 => Some("Linear"),
                1 => Some("BatchNorm"),
                _ => None,
            }
        } else {
            None
        };

        by_weight_rank.or_else(|| {
            if name.ends_with(".gamma") || name.ends_with(".beta") {
                Some("LayerNorm")
            } else if name.ends_with(".embedding") {
                Some("Embedding")
            } else {
                None
            }
        })
    };

    params
        .iter()
        .find_map(|(name, shape, _, _)| classify(name.as_str(), shape.as_slice()))
        .unwrap_or("Unknown")
        .to_string()
}
466
/// Derives display strings for a layer's `(input, output)` shapes from the
/// first `.weight` parameter of rank 2 or 4.
///
/// A rank-2 weight `[out, in]` gives `("[batch, in]", "[batch, out]")`; a
/// rank-4 weight `[out, in, kh, kw]` gives `("[batch, in, H, W]",
/// "[batch, out, H', W']")`. Returns `("-", "-")` when no weight qualifies.
fn infer_shapes(params: &[(String, Vec<usize>, usize, bool)]) -> (String, String) {
    params
        .iter()
        .filter(|(name, _, _, _)| name.ends_with(".weight"))
        .find_map(|(_, shape, _, _)| match shape.as_slice() {
            [out_features, in_features] => Some((
                format!("[batch, {}]", in_features),
                format!("[batch, {}]", out_features),
            )),
            [out_channels, in_channels, _, _] => Some((
                format!("[batch, {}, H, W]", in_channels),
                format!("[batch, {}, H', W']", out_channels),
            )),
            _ => None,
        })
        .unwrap_or_else(|| ("-".to_string(), "-".to_string()))
}