mod error;
#[cfg(feature = "json")]
mod functions;
pub mod v1;
use num_traits::Float;
use serde::{Deserialize, Deserializer, Serialize};
use crate::Network;
pub use error::Error;
#[cfg(feature = "json")]
pub(crate) use functions::*;
pub type Metadata = v1::Metadata;
pub type Data<T, E> = v1::Data<T, E>;
/// A versioned, serializable encoding of a network together with its accompanying data.
///
/// Serde represents this enum in adjacently tagged form: the variant name is stored under the
/// `"version"` key and the variant's payload under the `"network"` key.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "version", content = "network")]
pub enum PortableCGE<T: Float, E> {
/// Version 1 of the encoding, written with the version tag `"1"`.
#[serde(rename = "1")]
V1(v1::Data<T, E>),
}
impl<T: Float, E> PortableCGE<T, E> {
pub fn build(
self,
with_state: WithRecurrentState,
) -> Result<(Network<T>, CommonMetadata, Extra<E>), Error> {
match self {
Self::V1(e) => e.build(with_state),
}
}
}
impl<T: Float, E> From<v1::Data<T, E>> for PortableCGE<T, E> {
fn from(net: Data<T, E>) -> Self {
Self::V1(net)
}
}
/// Metadata in a version-independent form.
///
/// Every type used as `EncodingVersion::Metadata` must be convertible into this type.
#[derive(Clone, Debug)]
pub struct CommonMetadata {
/// An optional free-form description.
pub description: Option<String>,
}
impl CommonMetadata {
fn new(description: Option<String>) -> Self {
Self { description }
}
pub fn into_latest_version(self) -> Option<Metadata> {
Some(Metadata::new(self.description))
}
pub fn into_v1(self) -> v1::Metadata {
v1::Metadata::new(self.description)
}
}
/// A boolean flag passed to `EncodingVersion::new`, `EncodingVersion::build`, and
/// `PortableCGE::build`.
///
/// NOTE(review): judging by the name, this presumably selects whether a network's recurrent
/// state is included when encoding or building — confirm against the `v1` implementation.
#[derive(Clone, Copy, Debug)]
pub struct WithRecurrentState(pub bool);
/// The interface implemented by each version of the encoded data.
///
/// Implementors must be convertible into `PortableCGE`.
pub trait EncodingVersion<T: Float, E>: Into<PortableCGE<T, E>> {
/// The metadata type of this version; it must be convertible into `CommonMetadata`.
type Metadata: Into<CommonMetadata>;
/// Creates encoded data from a network, its metadata, and the extra data, and returns it
/// wrapped in `PortableCGE`.
// `new` deliberately returns `PortableCGE` instead of `Self`, so the lint is silenced.
#[allow(clippy::new_ret_no_self)]
fn new(
network: &Network<T>,
metadata: Self::Metadata,
extra: E,
with_state: WithRecurrentState,
) -> PortableCGE<T, E>;
/// Consumes the encoded data and builds it.
///
/// On success, returns the network, its metadata as `CommonMetadata`, and the extra data;
/// on failure, returns an `Error`.
fn build(
self,
with_state: WithRecurrentState,
) -> Result<(Network<T>, CommonMetadata, Extra<E>), Error>;
}
/// Links a version's metadata type to the data type of the same version.
///
/// Implementors must be convertible into `CommonMetadata`.
pub trait MetadataVersion<T: Float, E>: Into<CommonMetadata> + Sized {
/// The encoding type that uses the implementing type as its `Metadata`.
type Data: EncodingVersion<T, E, Metadata = Self>;
}
/// Extra data of type `E`, or a marker that the stored value could not be read as `E`.
///
/// The enum is `untagged`, so when deserializing serde tries the variants in declaration
/// order: first `Ok` as an `E`, then the `Other` fallback.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(untagged)]
pub enum Extra<E> {
/// Extra data that was successfully read as `E`.
Ok(E),
/// Fallback variant: its `deserialize_with` function consumes and discards whatever value
/// is present.
#[serde(deserialize_with = "deserialize_ignore_any")]
Other,
}
impl<E> Extra<E> {
pub fn unwrap(self) -> E {
if let Self::Ok(data) = self {
data
} else {
panic!("called `unwrap` on an `Other` value");
}
}
pub fn is_ok(&self) -> bool {
matches!(self, Self::Ok(_))
}
pub fn is_other(&self) -> bool {
matches!(self, Self::Other)
}
}
/// Consumes whatever value the deserializer holds, discards it, and returns `T::default()`.
///
/// Referenced by the `deserialize_with` attribute on `Extra::Other`.
///
/// # Errors
///
/// Propagates any error raised while skipping over the value.
fn deserialize_ignore_any<'de, D, T>(deserializer: D) -> Result<T, D::Error>
where
    D: Deserializer<'de>,
    T: Default,
{
    // The parsed `IgnoredAny` carries no information, so it is dropped immediately.
    serde::de::IgnoredAny::deserialize(deserializer)?;
    Ok(T::default())
}
#[cfg(test)]
mod tests {
    use std::fs;

    use super::*;

    /// Returns the path of `file_name` inside the crate's `test_data` directory.
    fn get_file_path(file_name: &str) -> String {
        let manifest_dir = env!("CARGO_MANIFEST_DIR");
        format!("{}/test_data/{}", manifest_dir, file_name)
    }

    /// Loading the same file yields `Extra::Other` when the requested extra type does not
    /// match the stored data, and `Extra::Ok` when it does.
    #[test]
    fn test_extra() {
        #[derive(Serialize, Deserialize)]
        struct Foo {
            x: i32,
            y: [f64; 2],
        }

        let path = get_file_path("with_extra_data_v1.cge");
        let no_state = WithRecurrentState(false);

        let (_, _, as_unit) = Network::<f64>::load_file::<(), _>(&path, no_state).unwrap();
        assert!(as_unit.is_other());

        let (_, _, as_foo) = Network::<f64>::load_file::<Foo, _>(&path, no_state).unwrap();
        assert!(as_foo.is_ok());
    }

    /// A version 1 file round-trips through load and save unchanged, and loading it with and
    /// without recurrent state produces networks that evaluate differently.
    #[test]
    fn test_v1() {
        let original = fs::read_to_string(get_file_path("test_network_v1.cge")).unwrap();

        let (mut with_state, metadata, extra) =
            Network::<f64>::load_str::<()>(&original, WithRecurrentState(true)).unwrap();
        let reencoded = with_state
            .to_string(metadata.into_v1(), extra, WithRecurrentState(true))
            .unwrap();
        assert_eq!(original.trim(), reencoded.trim());

        let (mut without_state, _, _) =
            Network::<f64>::load_str::<()>(&original, WithRecurrentState(false)).unwrap();
        let inputs = &[1.0, 1.0];
        assert_ne!(with_state.evaluate(inputs), without_state.evaluate(inputs));
    }
}