1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
/// Text plot.
///
/// Holds a training series and a validation series of `(x, y)` points. The `x` coordinate
/// of every point comes from an iteration counter shared by both series.
pub struct TextPlot {
    /// Training points, stored as `(iteration, value)`.
    train: Vec<(f32, f32)>,
    /// Validation points, stored as `(iteration, value)`.
    valid: Vec<(f32, f32)>,
    /// Largest x-span a series may cover; older points beyond it are dropped on update.
    max_values: usize,
    /// Counter bumped by every training *and* validation update.
    iteration: usize,
}

impl Default for TextPlot {
    fn default() -> Self {
        Self::new()
    }
}

impl TextPlot {
    /// Creates a new, empty text plot whose series each keep an x-span of at most 10000.
    pub fn new() -> Self {
        Self {
            train: Vec::new(),
            valid: Vec::new(),
            max_values: 10000,
            iteration: 0,
        }
    }

    /// Merges two text plots.
    ///
    /// The points of `other` are appended after those of `self` as-is (their x values are
    /// not shifted). The window is not re-applied to the merged series.
    ///
    /// # Arguments
    ///
    /// * `self` - The first text plot.
    /// * `other` - The second text plot.
    ///
    /// # Returns
    ///
    /// The merged text plot, using the smaller of the two `max_values` and the sum of the
    /// two iteration counters.
    pub fn merge(mut self, mut other: Self) -> Self {
        self.train.append(&mut other.train);
        self.valid.append(&mut other.valid);
        self.max_values = self.max_values.min(other.max_values);
        self.iteration += other.iteration;
        self
    }

    /// Updates the text plot with a new item for training.
    ///
    /// The shared iteration counter is incremented and used as the new point's x value.
    /// Afterwards, every leading training point lying more than `max_values` behind the
    /// newest one is dropped.
    ///
    /// # Arguments
    ///
    /// * `item` - The new item.
    pub fn update_train(&mut self, item: f32) {
        self.iteration += 1;
        Self::push_windowed(&mut self.train, (self.iteration as f32, item), self.max_values);
    }

    /// Updates the text plot with a new item for validation.
    ///
    /// The shared iteration counter is incremented and used as the new point's x value.
    /// Afterwards, every leading validation point lying more than `max_values` behind the
    /// newest one is dropped.
    ///
    /// # Arguments
    ///
    /// * `item` - The new item.
    pub fn update_valid(&mut self, item: f32) {
        self.iteration += 1;
        Self::push_windowed(&mut self.valid, (self.iteration as f32, item), self.max_values);
    }

    /// Appends `point` to `points`, then drops every leading point whose x lies more than
    /// `max_values` behind the new point's x.
    ///
    /// Because the counter is shared between the two series, the newest x can jump by more
    /// than one step. Several points may therefore have to go at once; removing only one
    /// per update would let stale points accumulate.
    fn push_windowed(points: &mut Vec<(f32, f32)>, point: (f32, f32), max_values: usize) {
        points.push(point);

        let window = max_values as f32;
        let x_newest = point.0;
        // The freshly pushed point itself always stays (its distance is 0), so at least one
        // point survives.
        let stale = points
            .iter()
            .take_while(|&&(x, _)| x_newest - x > window)
            .count();
        points.drain(..stale);
    }

    /// Renders the text plot.
    ///
    /// The x-axis spans every point of both series. With no data, a unit range is used.
    /// When all points share one x value, the range is widened around it so the chart
    /// never receives an inverted or zero-width x-axis.
    ///
    /// # Returns
    ///
    /// The rendered text plot.
    #[cfg(feature = "ui")]
    pub fn render(&self) -> String {
        use rgb::RGB8;
        use std::cmp::Ordering;
        use terminal_size::{terminal_size, Height, Width};
        use textplots::{Chart, ColorPlot, Shape};

        let train_color = RGB8::new(255, 140, 140);
        let valid_color = RGB8::new(140, 140, 255);

        // Scan every point rather than trusting first()/last(): `merge` can leave a series
        // whose x values are not sorted.
        let (x_lo, x_hi) = self
            .train
            .iter()
            .chain(self.valid.iter())
            .fold((f32::INFINITY, f32::NEG_INFINITY), |(lo, hi), &(x, _)| {
                (lo.min(x), hi.max(x))
            });
        let (x_min, x_max) = match x_lo.partial_cmp(&x_hi) {
            Some(Ordering::Less) => (x_lo, x_hi),
            // A single distinct x value would give the chart a zero-width axis.
            Some(Ordering::Equal) => (x_lo - 0.5, x_hi + 0.5),
            // No points at all: the fold left the inverted infinite seeds.
            _ => (0.0, 1.0),
        };

        // NOTE(review): assumes textplots measures width in Braille dots (two per terminal
        // column); the `- 16` presumably leaves room for axis labels.
        let (width, height) = match terminal_size() {
            Some((Width(w), Height(_))) => (u32::max(64, w.into()) * 2 - 16, 64),
            None => (256, 64),
        };

        Chart::new(width, height, x_min, x_max)
            .linecolorplot(&Shape::Lines(&self.train), train_color)
            .linecolorplot(&Shape::Lines(&self.valid), valid_color)
            .to_string()
    }

    /// Renders the text plot.
    ///
    /// # Returns
    ///
    /// The rendered text plot.
    ///
    /// # Panics
    ///
    /// Always panics: rendering requires the `ui` feature.
    #[cfg(not(feature = "ui"))]
    pub fn render(&self) -> String {
        panic!("ui feature not enabled on burn-train")
    }
}