Skip to main content

axonml_tui/views/
model.rs

1//! Model View - Display Neural Network Architecture
2//!
3//! Shows model layers, parameters, shapes, and structure.
4//!
5//! @version 0.1.0
6//! @author AutomataNexus Development Team
7
8use std::path::Path;
9
10use ratatui::{
11    layout::{Constraint, Direction, Layout, Rect},
12    style::Style,
13    text::{Line, Span},
14    widgets::{Block, Borders, List, ListItem, ListState, Paragraph, Scrollbar, ScrollbarOrientation, ScrollbarState},
15    Frame,
16};
17
18use crate::theme::AxonmlTheme;
19use axonml_serialize::load_state_dict;
20
21// =============================================================================
22// Types
23// =============================================================================
24
/// Display summary of one layer: a group of parameters sharing a name prefix.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LayerInfo {
    /// Layer name (this view's parser uses the parameter name up to its last `.`).
    pub name: String,
    /// Heuristically inferred layer kind, e.g. `"Linear"` or `"Conv2d"`.
    pub layer_type: String,
    /// Human-readable input shape, or `"-"` when it cannot be inferred.
    pub input_shape: String,
    /// Human-readable output shape, or `"-"` when it cannot be inferred.
    pub output_shape: String,
    /// Total number of scalar parameters in this layer.
    pub params: usize,
    /// Whether any parameter of this layer requires gradients.
    pub trainable: bool,
}
35
36/// Model information
37#[derive(Debug, Clone)]
38pub struct ModelInfo {
39    pub name: String,
40    pub layers: Vec<LayerInfo>,
41    pub total_params: usize,
42    pub trainable_params: usize,
43    pub file_size: u64,
44    pub format: String,
45}
46
47// =============================================================================
48// Model View
49// =============================================================================
50
51/// Model architecture view state
52pub struct ModelView {
53    /// Loaded model info
54    pub model: Option<ModelInfo>,
55
56    /// Selected layer index
57    pub selected_layer: usize,
58
59    /// List state for layer navigation
60    pub list_state: ListState,
61
62    /// Scroll state
63    pub scroll_state: ScrollbarState,
64
65    /// Show detailed view
66    pub show_details: bool,
67}
68
69impl ModelView {
70    /// Create a new model view
71    pub fn new() -> Self {
72        let mut list_state = ListState::default();
73        list_state.select(Some(0));
74
75        Self {
76            model: None,
77            selected_layer: 0,
78            list_state,
79            scroll_state: ScrollbarState::default(),
80            show_details: false,
81        }
82    }
83
84    /// Load a model from file
85    pub fn load_model(&mut self, path: &Path) -> Result<(), String> {
86        // For now, create a demo model structure
87        // In real implementation, this would parse actual model files
88        let model = self.parse_model_file(path)?;
89        let layer_count = model.layers.len();
90        self.model = Some(model);
91        self.selected_layer = 0;
92        self.list_state.select(Some(0));
93        self.scroll_state = ScrollbarState::default().content_length(layer_count);
94        Ok(())
95    }
96
97    /// Parse model file using axonml-serialize
98    fn parse_model_file(&self, path: &Path) -> Result<ModelInfo, String> {
99        let file_name = path
100            .file_name()
101            .and_then(|n| n.to_str())
102            .unwrap_or("model");
103
104        // Get file size
105        let file_size = std::fs::metadata(path)
106            .map(|m| m.len())
107            .unwrap_or(0);
108
109        // Detect format from extension
110        let format = path.extension()
111            .and_then(|ext| ext.to_str())
112            .map(|ext| match ext.to_lowercase().as_str() {
113                "axonml" => "Axonml",
114                "safetensors" => "SafeTensors",
115                "json" => "JSON",
116                "pt" | "pth" => "PyTorch",
117                "onnx" => "ONNX",
118                _ => "Unknown",
119            })
120            .unwrap_or("Unknown")
121            .to_string();
122
123        // Load the actual state dict
124        let state_dict = load_state_dict(path)
125            .map_err(|e| format!("Failed to load model: {}", e))?;
126
127        // Group parameters by layer prefix and extract layer info
128        let mut layer_map: std::collections::BTreeMap<String, Vec<(String, Vec<usize>, usize, bool)>> =
129            std::collections::BTreeMap::new();
130
131        for (param_name, entry) in state_dict.entries() {
132            let shape = entry.data.shape.clone();
133            let num_params: usize = shape.iter().product();
134
135            // Extract layer name from parameter name (e.g., "layer1.conv.weight" -> "layer1.conv")
136            let layer_name = if let Some(idx) = param_name.rfind('.') {
137                param_name[..idx].to_string()
138            } else {
139                param_name.clone()
140            };
141
142            layer_map
143                .entry(layer_name)
144                .or_default()
145                .push((param_name.clone(), shape, num_params, entry.requires_grad));
146        }
147
148        // Create LayerInfo for each unique layer
149        let mut layers = Vec::new();
150        for (layer_name, params) in layer_map {
151            let layer_num_params: usize = params.iter().map(|(_, _, p, _)| *p).sum();
152            let trainable = params.iter().any(|(_, _, _, t)| *t);
153
154            // Infer layer type from parameter names
155            let layer_type = infer_layer_type(&params);
156
157            // Get shape from weight parameter if available
158            let (input_shape, output_shape) = infer_shapes(&params);
159
160            layers.push(LayerInfo {
161                name: layer_name,
162                layer_type,
163                input_shape,
164                output_shape,
165                params: layer_num_params,
166                trainable,
167            });
168        }
169
170        let total_params: usize = layers.iter().map(|l| l.params).sum();
171        let trainable_params: usize = layers
172            .iter()
173            .filter(|l| l.trainable)
174            .map(|l| l.params)
175            .sum();
176
177        Ok(ModelInfo {
178            name: file_name.to_string(),
179            layers,
180            total_params,
181            trainable_params,
182            file_size,
183            format,
184        })
185    }
186
187    /// Move selection up
188    pub fn select_prev(&mut self) {
189        if self.model.is_some() && self.selected_layer > 0 {
190            self.selected_layer -= 1;
191            self.list_state.select(Some(self.selected_layer));
192            self.scroll_state = self.scroll_state.position(self.selected_layer);
193        }
194    }
195
196    /// Move selection down
197    pub fn select_next(&mut self) {
198        if let Some(model) = &self.model {
199            if self.selected_layer < model.layers.len() - 1 {
200                self.selected_layer += 1;
201                self.list_state.select(Some(self.selected_layer));
202                self.scroll_state = self.scroll_state.position(self.selected_layer);
203            }
204        }
205    }
206
207    /// Toggle detailed view
208    pub fn toggle_details(&mut self) {
209        self.show_details = !self.show_details;
210    }
211
212    /// Render the model view
213    pub fn render(&mut self, frame: &mut Frame, area: Rect) {
214        if let Some(model) = self.model.clone() {
215            let chunks = Layout::default()
216                .direction(Direction::Vertical)
217                .constraints([
218                    Constraint::Length(5),  // Header
219                    Constraint::Min(10),    // Layers list
220                    Constraint::Length(8),  // Details panel
221                ])
222                .split(area);
223
224            self.render_header(frame, chunks[0], &model);
225            self.render_layers(frame, chunks[1], &model);
226            self.render_details(frame, chunks[2], &model);
227        } else {
228            self.render_empty(frame, area);
229        }
230    }
231
232    fn render_header(&self, frame: &mut Frame, area: Rect, model: &ModelInfo) {
233        let header_text = vec![
234            Line::from(vec![
235                Span::styled("Model: ", AxonmlTheme::muted()),
236                Span::styled(&model.name, AxonmlTheme::title()),
237            ]),
238            Line::from(vec![
239                Span::styled("Format: ", AxonmlTheme::muted()),
240                Span::styled(&model.format, AxonmlTheme::accent()),
241                Span::raw("  "),
242                Span::styled("Size: ", AxonmlTheme::muted()),
243                Span::styled(format_size(model.file_size), AxonmlTheme::accent()),
244            ]),
245            Line::from(vec![
246                Span::styled("Total Params: ", AxonmlTheme::muted()),
247                Span::styled(format_number(model.total_params), AxonmlTheme::metric_value()),
248                Span::raw("  "),
249                Span::styled("Trainable: ", AxonmlTheme::muted()),
250                Span::styled(format_number(model.trainable_params), AxonmlTheme::success()),
251            ]),
252        ];
253
254        let header = Paragraph::new(header_text)
255            .block(
256                Block::default()
257                    .borders(Borders::ALL)
258                    .border_style(AxonmlTheme::border())
259                    .title(Span::styled(" Model Info ", AxonmlTheme::header())),
260            );
261
262        frame.render_widget(header, area);
263    }
264
265    fn render_layers(&mut self, frame: &mut Frame, area: Rect, model: &ModelInfo) {
266        let items: Vec<ListItem> = model
267            .layers
268            .iter()
269            .enumerate()
270            .map(|(i, layer)| {
271                let style = if i == self.selected_layer {
272                    AxonmlTheme::selected()
273                } else {
274                    Style::default()
275                };
276
277                let content = Line::from(vec![
278                    Span::styled(
279                        format!("{:>2}. ", i + 1),
280                        AxonmlTheme::muted(),
281                    ),
282                    Span::styled(
283                        format!("{:<12}", layer.name),
284                        AxonmlTheme::layer_type(),
285                    ),
286                    Span::styled(
287                        format!("{:<15}", layer.layer_type),
288                        AxonmlTheme::accent(),
289                    ),
290                    Span::styled(
291                        format!("{:>15}", layer.output_shape),
292                        AxonmlTheme::layer_shape(),
293                    ),
294                    Span::styled(
295                        format!("{:>12}", format_number(layer.params)),
296                        if layer.trainable {
297                            AxonmlTheme::success()
298                        } else {
299                            AxonmlTheme::muted()
300                        },
301                    ),
302                ]);
303
304                ListItem::new(content).style(style)
305            })
306            .collect();
307
308        let list = List::new(items)
309            .block(
310                Block::default()
311                    .borders(Borders::ALL)
312                    .border_style(AxonmlTheme::border_focused())
313                    .title(Span::styled(" Layers ", AxonmlTheme::header())),
314            )
315            .highlight_style(AxonmlTheme::selected());
316
317        frame.render_stateful_widget(list, area, &mut self.list_state);
318
319        // Render scrollbar
320        let scrollbar = Scrollbar::new(ScrollbarOrientation::VerticalRight)
321            .begin_symbol(Some("▲"))
322            .end_symbol(Some("▼"));
323
324        frame.render_stateful_widget(
325            scrollbar,
326            area.inner(ratatui::layout::Margin { vertical: 1, horizontal: 0 }),
327            &mut self.scroll_state,
328        );
329    }
330
331    fn render_details(&self, frame: &mut Frame, area: Rect, model: &ModelInfo) {
332        let layer = &model.layers[self.selected_layer];
333
334        let details = vec![
335            Line::from(vec![
336                Span::styled("Layer: ", AxonmlTheme::muted()),
337                Span::styled(&layer.name, AxonmlTheme::title()),
338                Span::styled(" (", AxonmlTheme::muted()),
339                Span::styled(&layer.layer_type, AxonmlTheme::accent()),
340                Span::styled(")", AxonmlTheme::muted()),
341            ]),
342            Line::from(vec![
343                Span::styled("Input:  ", AxonmlTheme::muted()),
344                Span::styled(&layer.input_shape, AxonmlTheme::layer_shape()),
345            ]),
346            Line::from(vec![
347                Span::styled("Output: ", AxonmlTheme::muted()),
348                Span::styled(&layer.output_shape, AxonmlTheme::layer_shape()),
349            ]),
350            Line::from(vec![
351                Span::styled("Params: ", AxonmlTheme::muted()),
352                Span::styled(format_number(layer.params), AxonmlTheme::metric_value()),
353                Span::raw("  "),
354                Span::styled("Trainable: ", AxonmlTheme::muted()),
355                Span::styled(
356                    if layer.trainable { "Yes" } else { "No" },
357                    if layer.trainable {
358                        AxonmlTheme::success()
359                    } else {
360                        AxonmlTheme::muted()
361                    },
362                ),
363            ]),
364        ];
365
366        let details_widget = Paragraph::new(details)
367            .block(
368                Block::default()
369                    .borders(Borders::ALL)
370                    .border_style(AxonmlTheme::border())
371                    .title(Span::styled(" Layer Details ", AxonmlTheme::header())),
372            );
373
374        frame.render_widget(details_widget, area);
375    }
376
377    fn render_empty(&self, frame: &mut Frame, area: Rect) {
378        let text = vec![
379            Line::from(""),
380            Line::from(Span::styled(
381                "No model loaded",
382                AxonmlTheme::muted(),
383            )),
384            Line::from(""),
385            Line::from(Span::styled(
386                "Press 'o' to open a model file",
387                AxonmlTheme::info(),
388            )),
389            Line::from(Span::styled(
390                "or use: axonml tui --model <path>",
391                AxonmlTheme::muted(),
392            )),
393        ];
394
395        let paragraph = Paragraph::new(text)
396            .block(
397                Block::default()
398                    .borders(Borders::ALL)
399                    .border_style(AxonmlTheme::border())
400                    .title(Span::styled(" Model Architecture ", AxonmlTheme::header())),
401            )
402            .alignment(ratatui::layout::Alignment::Center);
403
404        frame.render_widget(paragraph, area);
405    }
406}
407
408impl Default for ModelView {
409    fn default() -> Self {
410        Self::new()
411    }
412}
413
414// =============================================================================
415// Helpers
416// =============================================================================
417
/// Render a byte count with binary (1024-based) units.
///
/// Values of 1 KB and above are shown with two decimals in the largest unit
/// they reach (KB, MB or GB); smaller values are shown as whole bytes.
fn format_size(bytes: u64) -> String {
    const KIB: u64 = 1 << 10;
    const MIB: u64 = 1 << 20;
    const GIB: u64 = 1 << 30;

    let scaled = |unit: u64| bytes as f64 / unit as f64;

    match bytes {
        b if b >= GIB => format!("{:.2} GB", scaled(GIB)),
        b if b >= MIB => format!("{:.2} MB", scaled(MIB)),
        b if b >= KIB => format!("{:.2} KB", scaled(KIB)),
        b => format!("{} B", b),
    }
}
433
/// Format a parameter count compactly using decimal suffixes.
///
/// Counts below 1,000 are printed as-is. Larger counts use `K` (10^3),
/// `M` (10^6) or `B` (10^9, billions) with two decimals. When rounding to two
/// decimals would print `1000.00` of a unit (e.g. 999,999 would read
/// `1000.00K`), the next larger unit is used instead (`1.00M`).
fn format_number(n: usize) -> String {
    const UNITS: [(f64, &str); 3] = [(1e3, "K"), (1e6, "M"), (1e9, "B")];

    let value = n as f64;

    // Largest unit whose threshold the count reaches; none means print verbatim.
    let mut idx = match UNITS.iter().rposition(|&(scale, _)| value >= scale) {
        Some(idx) => idx,
        None => return n.to_string(),
    };

    // Promote when the two-decimal rounding would carry over to 1000.00.
    while idx + 1 < UNITS.len() && (value / UNITS[idx].0 * 100.0).round() >= 100_000.0 {
        idx += 1;
    }

    let (scale, suffix) = UNITS[idx];
    format!("{:.2}{}", value / scale, suffix)
}
443
/// Infer a layer's type from its parameters' names and shapes.
///
/// Parameters are examined in the given order and the first match wins:
/// - a `.weight` of rank 1/2/3/4/5 is reported as
///   `BatchNorm`/`Linear`/`Conv1d`/`Conv2d`/`Conv3d`;
/// - a `.gamma` or `.beta` parameter means `LayerNorm`;
/// - a `.embedding` parameter means `Embedding`.
///
/// If nothing matched, `.running_mean`/`.running_var` buffers still identify a
/// (non-affine) `BatchNorm`; otherwise `"Unknown"` is returned.
///
/// These are shape heuristics only; for example an embedding table stored as
/// a 2-D `.weight` cannot be told apart from a `Linear` weight here.
fn infer_layer_type(params: &[(String, Vec<usize>, usize, bool)]) -> String {
    for (name, shape, _, _) in params {
        if name.ends_with(".weight") {
            let kind = match shape.len() {
                1 => Some("BatchNorm"),
                2 => Some("Linear"),
                3 => Some("Conv1d"),
                4 => Some("Conv2d"),
                5 => Some("Conv3d"),
                _ => None,
            };
            if let Some(kind) = kind {
                return kind.to_string();
            }
        }
        if name.ends_with(".gamma") || name.ends_with(".beta") {
            return "LayerNorm".to_string();
        }
        if name.ends_with(".embedding") {
            return "Embedding".to_string();
        }
    }

    // Fallback only: a BatchNorm without affine parameters stores just its stats.
    let has_running_stats = params.iter().any(|(name, _, _, _)| {
        name.ends_with(".running_mean") || name.ends_with(".running_var")
    });
    if has_running_stats {
        "BatchNorm".to_string()
    } else {
        "Unknown".to_string()
    }
}
466
/// Infer human-readable input and output shapes from a layer's `.weight`.
///
/// Supported weight layouts (the first matching `.weight` wins):
/// - Linear `[out_features, in_features]` -> `[batch, in]` / `[batch, out]`
/// - Conv1d `[out_ch, in_ch, k]` -> `[batch, in, L]` / `[batch, out, L']`
/// - Conv2d `[out_ch, in_ch, kh, kw]` -> `[batch, in, H, W]` / `[batch, out, H', W']`
/// - Conv3d `[out_ch, in_ch, kd, kh, kw]` -> `[batch, in, D, H, W]` / `[batch, out, D', H', W']`
///
/// Returns `("-", "-")` when no supported weight is present.
fn infer_shapes(params: &[(String, Vec<usize>, usize, bool)]) -> (String, String) {
    for (name, shape, _, _) in params {
        if !name.ends_with(".weight") {
            continue;
        }
        // Trailing spatial dimensions for input and output, by weight rank.
        let (spatial_in, spatial_out) = match shape.len() {
            2 => ("", ""),
            3 => (", L", ", L'"),
            4 => (", H, W", ", H', W'"),
            5 => (", D, H, W", ", D', H', W'"),
            _ => continue,
        };
        return (
            format!("[batch, {}{}]", shape[1], spatial_in),
            format!("[batch, {}{}]", shape[0], spatial_out),
        );
    }
    ("-".to_string(), "-".to_string())
}