nncombinator/
persistence.rs

//! Implementation of persistence for neural network models
2
3use std::fmt::Display;
4use std::fs::{File, OpenOptions};
5use std::io;
6use std::io::{BufRead, BufReader, BufWriter, Read, Write};
7use std::path::Path;
8use std::str::FromStr;
9use crate::error::*;
10
11pub trait Persistence<U,P,K> where K: PersistenceType {
12    /// Load Model
13    /// # Arguments
14    /// * `persistence` - model persistent object
15    ///
16    /// # Errors
17    ///
18    /// This function may return the following errors
19    /// * [`ConfigReadError`]
20    fn load(&mut self, persistence:&mut P) -> Result<(),ConfigReadError>;
21    /// Save Model
22    /// # Arguments
23    /// * `persistence` - model persistent object
24    ///
25    /// # Errors
26    ///
27    /// This function may return the following errors
28    /// * [`PersistenceError`]
29    fn save(&mut self, persistence:&mut P) -> Result<(), PersistenceError>;
30}
/// Marker trait for the types accepted as the `K` parameter of [`Persistence`]
pub trait PersistenceType {}

/// Marker type implementing [`PersistenceType`]
///
/// NOTE(review): presumably selects a model-specific persistence format; confirm against the implementors of [`Persistence`].
pub struct Specialized;

impl PersistenceType for Specialized {}

/// Marker type implementing [`PersistenceType`]
///
/// NOTE(review): presumably selects the flat format described by [`LinearPersistence`]; confirm against the implementors of [`Persistence`].
pub struct Linear;

impl PersistenceType for Linear {}
36
/// Trait for persistence objects that can write their contents to a file
pub trait SaveToFile<U> {
    /// Save to File
    /// # Arguments
    /// * `file` - Destination path
    ///
    /// # Errors
    ///
    /// This function may return the following errors
    /// * [`io::Error`]
    fn save<P>(&self, file: P) -> Result<(), io::Error>
    where
        P: AsRef<Path>;
}
49/// Trait to define an implementation to persist the model in a flat data structure
50pub trait LinearPersistence<U> {
51    /// Read to restore the persisted model
52    ///
53    /// # Errors
54    ///
55    /// This function may return the following errors
56    /// * [`ConfigReadError`]
57    fn read(&mut self) -> Result<U, ConfigReadError>;
58    /// Write to persist model information
59    /// # Arguments
60    /// * `u` - Weight value
61    ///
62    /// # Errors
63    ///
64    /// This function may return the following errors
65    /// * [`PersistenceError`]
66    fn write(&mut self, u:U) -> Result<(), PersistenceError>;
67    /// Has the read position of the persisted information reached EOF?
68    ///
69    /// # Errors
70    ///
71    /// This function may return the following errors
72    /// * [`ConfigReadError`]
73    fn verify_eof(&mut self) -> Result<(),ConfigReadError>;
74}
/// Item passed to a persistence object when saving a model: either a value or a marker
/// identifying a layer or unit boundary
pub enum UnitOrMarker<U> {
    /// A value to persist (not a boundary).
    Unit(U),
    /// Marker for the start of a layer (layer boundary).
    LayerStart,
    /// Marker for the start of a run of units (unit boundary).
    UnitsStart,
}
84/// Persistent object for saving to a text file
85pub struct TextFilePersistence<U> where U: FromStr + Sized {
86    reader:Option<BufReader<File>>,
87    line:Option<Vec<String>>,
88    index:usize,
89    data:Vec<UnitOrMarker<U>>
90}
91impl<U> TextFilePersistence<U> where U: FromStr + Sized {
92    /// Create an instance of TextFilePersistence
93    /// # Arguments
94    /// * `file` - File path to be persisted
95    ///
96    /// # Errors
97    ///
98    /// This function may return the following errors
99    /// * [`ConfigReadError`]
100    pub fn new<P: AsRef<Path>>(file:P) -> Result<TextFilePersistence<U>,ConfigReadError> {
101        if file.as_ref().exists() {
102            Ok(TextFilePersistence {
103                reader:Some(BufReader::new(OpenOptions::new().read(true).create(false).open(file)?)),
104                line: None,
105                index: 0usize,
106                data: Vec::new()
107            })
108        } else {
109            Ok(TextFilePersistence {
110                reader:None,
111                line: None,
112                index: 0usize,
113                data: Vec::new()
114            })
115        }
116    }
117
118    fn read_line(&mut self) -> Result<String, ConfigReadError> {
119        match self.reader {
120            Some(ref mut reader) => {
121                let mut buf = String::new();
122                let n = reader.read_line(&mut buf)?;
123
124                buf = buf.trim().to_string();
125
126                if n == 0 {
127                    Err(ConfigReadError::InvalidState(String::from(
128                        "End of input has been reached.")))
129                } else {
130                    Ok(buf)
131                }
132            },
133            None => {
134                Err(ConfigReadError::InvalidState(String::from(
135                    "File does not exist yet.")))
136            }
137        }
138    }
139
140    fn next_token(&mut self) -> Result<String, ConfigReadError> {
141        let t = match self.line {
142            None => {
143                self.index = 0;
144                let mut buf = self.read_line()?;
145
146                while match &*buf {
147                    "" => true,
148                    s => match s.chars().nth(0) {
149                        Some('#') => true,
150                        _ => false,
151                    }
152                } {
153                    buf = self.read_line()?;
154                }
155
156                let line = buf.split(" ").map(|s| s.to_string()).collect::<Vec<String>>();
157                let t = (&line[self.index]).clone();
158                self.line = Some(line);
159                t
160            },
161            Some(ref line) => {
162                (&line[self.index]).clone()
163            }
164        };
165
166        self.index = self.index + 1;
167
168        if match self.line {
169            Some(ref line) if self.index >= line.len() => {
170                true
171            },
172            Some(_) => {
173                false
174            }
175            None => false,
176        } {
177            self.line = None;
178        }
179
180        Ok(t)
181    }
182
183    /// Has the read position of the persisted information reached EOF?
184    ///
185    /// # Errors
186    ///
187    /// This function may return the following errors
188    /// * [`ConfigReadError`]
189    pub fn verify_eof(&mut self) -> Result<(),ConfigReadError> {
190        match self.reader {
191            Some(ref mut reader) => {
192                let mut buf = String::new();
193
194                loop {
195                    let n = reader.read_line(&mut buf)?;
196
197                    if n == 0 {
198                        return Ok(());
199                    }
200
201                    buf = buf.trim().to_string();
202
203                    if !buf.is_empty() {
204                        return Err(ConfigReadError::InvalidState(
205                            String::from("Data loaded , but the input has not reached the end.")));
206                    } else {
207                        buf.clear();
208                    }
209                }
210            },
211            None => {
212                Err(ConfigReadError::InvalidState(String::from(
213                    "File does not exist yet.")))
214            }
215        }
216    }
217}
218impl<U> TextFilePersistence<U> where U: FromStr + Sized, ConfigReadError: From<<U as FromStr>::Err> {
219    /// Read the weight values from a file
220    ///
221    /// # Errors
222    ///
223    /// This function may return the following errors
224    /// * [`ConfigReadError`]
225    pub fn read(&mut self) -> Result<U, ConfigReadError> {
226        Ok(self.next_token()?.parse::<U>()?)
227    }
228}
229impl<U> TextFilePersistence<U> where U: FromStr + Sized {
230    /// Layer weights are added to the end of the internal buffer
231    /// # Arguments
232    /// * `v` - Weight value
233    pub fn write(&mut self,v:UnitOrMarker<U>) {
234        self.data.push(v);
235    }
236}
237impl<U> SaveToFile<U> for TextFilePersistence<U> where U: FromStr + Sized + Display {
238    fn save<P: AsRef<Path>>(&self,file:P) -> Result<(),io::Error> {
239        let mut bw = BufWriter::new(OpenOptions::new().write(true).create(true).open(file)?);
240
241        for u in self.data.iter() {
242            match u {
243                UnitOrMarker::Unit(u) => {
244                    bw.write(format!("{} ",u).as_bytes())?;
245                },
246                UnitOrMarker::LayerStart => {
247                    bw.write(b"#layer\n")?;
248                },
249                UnitOrMarker::UnitsStart => {
250                    bw.write(b"\n")?;
251                }
252            }
253        }
254
255        Ok(())
256    }
257}
/// Persistent object that stores and loads values in a fixed length record format
/// (binary file)
pub struct BinFilePersistence<U> {
    /// Reader over the file being loaded; `None` when the file did not exist at construction
    reader: Option<BufReader<File>>,
    /// Values buffered by `write`, waiting to be saved
    data: Vec<U>,
}
264impl<U> BinFilePersistence<U> {
265    /// Create an instance of TextFilePersistence
266    /// # Arguments
267    /// * `file` - File path to be persisted
268    ///
269    /// # Errors
270    ///
271    /// This function may return the following errors
272    /// * [`ConfigReadError`]
273    pub fn new<P: AsRef<Path>>(file:P) -> Result<BinFilePersistence<U>, ConfigReadError> {
274        if file.as_ref().exists() {
275            Ok(BinFilePersistence {
276                reader:Some(BufReader::new(OpenOptions::new().read(true).create(false).open(file)?)),
277                data:Vec::new()
278            })
279        } else {
280            Ok(BinFilePersistence {
281                reader:None,
282                data:Vec::new()
283            })
284        }
285    }
286}
287impl LinearPersistence<f64> for BinFilePersistence<f64> {
288    fn read(&mut self) -> Result<f64, ConfigReadError> {
289        match self.reader {
290            Some(ref mut reader) => {
291                let mut buf = [0; 8];
292
293                reader.read_exact(&mut buf)?;
294
295                Ok(f64::from_bits(
296                    (buf[0] as u64) << 56 |
297                        (buf[1] as u64) << 48 |
298                        (buf[2] as u64) << 40 |
299                        (buf[3] as u64) << 32 |
300                        (buf[4] as u64) << 24 |
301                        (buf[5] as u64) << 16 |
302                        (buf[6] as u64) << 8 |
303                        buf[7] as u64)
304                )
305            },
306            None => {
307                Err(ConfigReadError::InvalidState(String::from(
308                    "File does not exist yet.")))
309            }
310        }
311    }
312
313    fn write(&mut self, u: f64) -> Result<(), PersistenceError> {
314        self.data.push(u);
315        Ok(())
316    }
317
318    fn verify_eof(&mut self) -> Result<(), ConfigReadError> {
319        match self.reader {
320            Some(ref mut reader) => {
321                let mut buf: [u8; 1] = [0];
322
323                let n = reader.read(&mut buf)?;
324
325                if n == 0 {
326                    Ok(())
327                } else {
328                    Err(ConfigReadError::InvalidState(String::from("Data loaded , but the input has not reached the end.")))
329                }
330            },
331            None => {
332                Err(ConfigReadError::InvalidState(String::from(
333                    "File does not exist yet.")))
334            }
335        }
336    }
337}
338impl LinearPersistence<f32> for BinFilePersistence<f32> {
339    fn read(&mut self) -> Result<f32, ConfigReadError> {
340        match self.reader {
341            Some(ref mut reader) => {
342                let mut buf = [0; 4];
343
344                reader.read_exact(&mut buf)?;
345
346                Ok(f32::from_bits(
347                    (buf[0] as u32) << 24 |
348                        (buf[1] as u32) << 16 |
349                        (buf[2] as u32) << 8 |
350                        buf[3] as u32)
351                )
352            },
353            None => {
354                Err(ConfigReadError::InvalidState(String::from(
355                    "File does not exist yet.")))
356            }
357        }
358    }
359
360    fn write(&mut self, u: f32) -> Result<(), PersistenceError> {
361        self.data.push(u);
362        Ok(())
363    }
364
365    fn verify_eof(&mut self) -> Result<(), ConfigReadError> {
366        match self.reader {
367            Some(ref mut reader) => {
368                let mut buf: [u8; 1] = [0];
369
370                let n = reader.read(&mut buf)?;
371
372                if n == 0 {
373                    Ok(())
374                } else {
375                    Err(ConfigReadError::InvalidState(String::from("Data loaded , but the input has not reached the end.")))
376                }
377            },
378            None => {
379                Err(ConfigReadError::InvalidState(String::from(
380                    "File does not exist yet.")))
381            }
382        }
383    }
384}
385impl SaveToFile<f64> for BinFilePersistence<f64> {
386    fn save<P: AsRef<Path>>(&self,file:P) -> Result<(),io::Error> {
387        let mut bw = BufWriter::new(OpenOptions::new().write(true).create(true).open(file)?);
388
389        for u in self.data.iter() {
390            let mut buf = [0; 8];
391            let bits = u.to_bits();
392
393            buf[0] = (bits >> 56 & 0xff) as u8;
394            buf[1] = (bits >> 48 & 0xff) as u8;
395            buf[2] = (bits >> 40 & 0xff) as u8;
396            buf[3] = (bits >> 32 & 0xff) as u8;
397            buf[4] = (bits >> 24 & 0xff) as u8;
398            buf[5] = (bits >> 16 & 0xff) as u8;
399            buf[6] = (bits >> 8 & 0xff) as u8;
400            buf[7] = (bits & 0xff) as u8;
401
402            bw.write(&buf)?;
403        }
404
405        Ok(())
406    }
407}
408impl SaveToFile<f32> for BinFilePersistence<f32> {
409    fn save<P: AsRef<Path>>(&self,file:P) -> Result<(),io::Error> {
410        let mut bw = BufWriter::new(OpenOptions::new().write(true).create(true).open(file)?);
411
412        for u in self.data.iter() {
413            let mut buf = [0; 4];
414            let bits = u.to_bits();
415            buf[0] = (bits >> 24 & 0xff) as u8;
416            buf[1] = (bits >> 16 & 0xff) as u8;
417            buf[2] = (bits >> 8 & 0xff) as u8;
418            buf[3] = (bits & 0xff) as u8;
419
420            bw.write(&buf)?;
421        }
422
423        Ok(())
424    }
425}