use super::{bin_config, PrecisionSettings, Recorder, RecorderError};
use alloc::vec::Vec;
use serde::{de::DeserializeOwned, Serialize};
/// A [`Recorder`] whose recorded output is a plain byte vector.
///
/// Implementors take no extra arguments when recording (`RecordArgs = ()`),
/// produce a `Vec<u8>` (`RecordOutput`), and load from a `Vec<u8>` (`LoadArgs`).
/// The trait adds no methods of its own; it only pins those associated types so
/// callers can be generic over "recorders that turn records into bytes".
pub trait BytesRecorder:
    Recorder<RecordArgs = (), LoadArgs = Vec<u8>, RecordOutput = Vec<u8>>
{
}
/// Byte recorder backed by the `bincode` format (configured through `bin_config()`).
///
/// The type parameter `S` carries no data; it is only held as a marker and is
/// exposed as the recorder's `Settings` associated type.
#[derive(new, Debug, Default, Clone)]
pub struct BinBytesRecorder<S: PrecisionSettings> {
    // Zero-sized marker so the struct is generic over `S` without storing one.
    _settings: core::marker::PhantomData<S>,
}

// `BinBytesRecorder` satisfies the `Vec<u8>`-in / `Vec<u8>`-out contract.
impl<S: PrecisionSettings> BytesRecorder for BinBytesRecorder<S> {}
impl<S: PrecisionSettings> Recorder for BinBytesRecorder<S> {
    type Settings = S;
    type RecordArgs = ();
    type RecordOutput = Vec<u8>;
    type LoadArgs = Vec<u8>;
    fn save_item<I: Serialize>(
        &self,
        item: I,
        _args: Self::RecordArgs,
    ) -> Result<Self::RecordOutput, RecorderError> {
        Ok(bincode::serde::encode_to_vec(item, bin_config()).unwrap())
    }
    fn load_item<I: DeserializeOwned>(&self, args: Self::LoadArgs) -> Result<I, RecorderError> {
        let state = bincode::serde::decode_borrowed_from_slice(&args, bin_config()).unwrap();
        Ok(state)
    }
}
/// Byte recorder backed by MessagePack (`rmp_serde`), encoding struct fields by name.
///
/// Only available with the `std` feature. The type parameter `S` carries no
/// data; it is exposed as the recorder's `Settings` associated type.
#[cfg(feature = "std")]
#[derive(new, Debug, Default, Clone)]
pub struct NamedMpkBytesRecorder<S: PrecisionSettings> {
    // Zero-sized marker tying the recorder to its settings type.
    _settings: core::marker::PhantomData<S>,
}

// `NamedMpkBytesRecorder` satisfies the `Vec<u8>`-in / `Vec<u8>`-out contract.
#[cfg(feature = "std")]
impl<S: PrecisionSettings> BytesRecorder for NamedMpkBytesRecorder<S> {}
#[cfg(feature = "std")]
impl<S: PrecisionSettings> Recorder for NamedMpkBytesRecorder<S> {
    type Settings = S;
    type RecordArgs = ();
    type RecordOutput = Vec<u8>;
    type LoadArgs = Vec<u8>;

    /// Encodes `item` as MessagePack via `rmp_serde::encode::to_vec_named`.
    ///
    /// # Errors
    ///
    /// Any encoding failure is reported as `RecorderError::Unknown` holding the
    /// encoder's message.
    fn save_item<I: Serialize>(
        &self,
        item: I,
        _args: Self::RecordArgs,
    ) -> Result<Self::RecordOutput, RecorderError> {
        match rmp_serde::encode::to_vec_named(&item) {
            Ok(encoded) => Ok(encoded),
            Err(failure) => Err(RecorderError::Unknown(failure.to_string())),
        }
    }

    /// Decodes an item of type `I` from the MessagePack bytes in `args`.
    ///
    /// # Errors
    ///
    /// Any decoding failure is reported as `RecorderError::Unknown` holding the
    /// decoder's message.
    fn load_item<I: DeserializeOwned>(&self, args: Self::LoadArgs) -> Result<I, RecorderError> {
        let payload: &[u8] = &args;
        match rmp_serde::decode::from_slice(payload) {
            Ok(decoded) => Ok(decoded),
            Err(failure) => Err(RecorderError::Unknown(failure.to_string())),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{module::Module, nn, record::FullPrecisionSettings, TestBackend};

    /// Builds a fresh 32x32 linear layer with a bias term.
    pub fn create_model() -> nn::Linear<TestBackend> {
        nn::LinearConfig::new(32, 32).with_bias(true).init()
    }

    /// Round-trip check shared by every `BytesRecorder` backend.
    ///
    /// Two independently created models should serialise to different bytes.
    /// After loading the first model's bytes into the second model, re-recording
    /// the second model must reproduce the first model's bytes exactly.
    fn test_can_save_and_load<R: BytesRecorder>(recorder: R) {
        let source = create_model();
        let target = create_model();

        let source_bytes = recorder.record(source.into_record(), ()).unwrap();
        let target_bytes = recorder.record(target.clone().into_record(), ()).unwrap();

        let target_reloaded = target.load_record(recorder.load(source_bytes.clone()).unwrap());
        let reloaded_bytes = recorder.record(target_reloaded.into_record(), ()).unwrap();

        assert_ne!(source_bytes, target_bytes);
        assert_eq!(source_bytes, reloaded_bytes);
    }

    #[test]
    fn test_can_save_and_load_bin_format() {
        test_can_save_and_load(BinBytesRecorder::<FullPrecisionSettings>::default())
    }

    #[cfg(feature = "std")]
    #[test]
    fn test_can_save_and_load_named_mpk_format() {
        test_can_save_and_load(NamedMpkBytesRecorder::<FullPrecisionSettings>::default())
    }
}