1use std::io;
2
3use async_trait::async_trait;
4use bytes::Bytes;
5use destream::en;
6use futures::TryStreamExt;
7use get_size::GetSize;
8use get_size_derive::*;
9use safecast::{as_type, AsType};
10use tokio::fs;
11use tokio_util::io::StreamReader;
12
13use tc_chain::ChainBlock;
14use tc_collection::{btree, tensor};
15use tc_scalar::Scalar;
16use tc_value::Link;
17use tcgeneric::Map;
18
19#[derive(Clone, GetSize)]
21pub enum DenseBuffer {
22 F32(tensor::Buffer<f32>),
23 F64(tensor::Buffer<f64>),
24 I16(tensor::Buffer<i16>),
25 I32(tensor::Buffer<i32>),
26 I64(tensor::Buffer<i64>),
27 U8(tensor::Buffer<u8>),
28 U16(tensor::Buffer<u16>),
29 U32(tensor::Buffer<u32>),
30 U64(tensor::Buffer<u64>),
31}
32
33as_type!(DenseBuffer, F32, tensor::Buffer<f32>);
34as_type!(DenseBuffer, F64, tensor::Buffer<f64>);
35as_type!(DenseBuffer, I16, tensor::Buffer<i16>);
36as_type!(DenseBuffer, I32, tensor::Buffer<i32>);
37as_type!(DenseBuffer, I64, tensor::Buffer<i64>);
38as_type!(DenseBuffer, U8, tensor::Buffer<u8>);
39as_type!(DenseBuffer, U16, tensor::Buffer<u16>);
40as_type!(DenseBuffer, U32, tensor::Buffer<u32>);
41as_type!(DenseBuffer, U64, tensor::Buffer<u64>);
42
43impl<'en> en::ToStream<'en> for DenseBuffer {
44 fn to_stream<E: en::Encoder<'en>>(&'en self, encoder: E) -> Result<E::Ok, E::Error> {
45 match self {
46 Self::F32(this) => this.to_stream(encoder),
47 Self::F64(this) => this.to_stream(encoder),
48 Self::I16(this) => this.to_stream(encoder),
49 Self::I32(this) => this.to_stream(encoder),
50 Self::I64(this) => this.to_stream(encoder),
51 Self::U8(this) => this.to_stream(encoder),
52 Self::U16(this) => this.to_stream(encoder),
53 Self::U32(this) => this.to_stream(encoder),
54 Self::U64(this) => this.to_stream(encoder),
55 }
56 }
57}
58
59#[derive(Clone, GetSize)]
61pub enum CacheBlock {
62 BTree(btree::Node),
63 Chain(ChainBlock),
64 Class((Link, Map<Scalar>)),
65 Library(Map<Scalar>),
66 Sparse(tensor::Node),
67 Dense(DenseBuffer),
68}
69
70#[async_trait]
71impl<'en> tc_transact::fs::FileSave<'en> for CacheBlock {
72 async fn save(&'en self, file: &mut fs::File) -> Result<u64, io::Error> {
73 match self {
74 Self::BTree(node) => persist(node, file).await,
75 Self::Chain(block) => persist(block, file).await,
76 Self::Class(class) => persist(class, file).await,
77 Self::Library(library) => persist(library, file).await,
78 Self::Dense(dense) => persist(dense, file).await,
79 Self::Sparse(sparse) => persist(sparse, file).await,
80 }
81 }
82}
83
84as_type!(CacheBlock, BTree, btree::Node);
85as_type!(CacheBlock, Chain, ChainBlock);
86as_type!(CacheBlock, Class, (Link, Map<Scalar>));
87as_type!(CacheBlock, Library, Map<Scalar>);
88as_type!(CacheBlock, Sparse, tensor::Node);
89
90macro_rules! as_dense_type {
91 ($t:ty) => {
92 impl AsType<tensor::Buffer<$t>> for CacheBlock {
93 fn as_type(&self) -> Option<&tensor::Buffer<$t>> {
94 if let Self::Dense(block) = self {
95 block.as_type()
96 } else {
97 None
98 }
99 }
100
101 fn as_type_mut(&mut self) -> Option<&mut tensor::Buffer<$t>> {
102 if let Self::Dense(block) = self {
103 block.as_type_mut()
104 } else {
105 None
106 }
107 }
108
109 fn into_type(self) -> Option<tensor::Buffer<$t>> {
110 if let Self::Dense(block) = self {
111 block.into_type()
112 } else {
113 None
114 }
115 }
116 }
117
118 impl From<tensor::Buffer<$t>> for CacheBlock {
119 fn from(buffer: tensor::Buffer<$t>) -> Self {
120 Self::Dense(buffer.into())
121 }
122 }
123 };
124}
125
126as_dense_type!(f32);
127as_dense_type!(f64);
128as_dense_type!(i16);
129as_dense_type!(i32);
130as_dense_type!(i64);
131as_dense_type!(u8);
132as_dense_type!(u16);
133as_dense_type!(u32);
134as_dense_type!(u64);
135
136async fn persist<'en, T: en::ToStream<'en>>(
137 data: &'en T,
138 file: &mut fs::File,
139) -> Result<u64, io::Error> {
140 let encoded = tbon::en::encode(data)
141 .map_err(|cause| io::Error::new(io::ErrorKind::InvalidData, cause))?;
142
143 let mut reader = StreamReader::new(
144 encoded
145 .map_ok(Bytes::from)
146 .map_err(|cause| io::Error::new(io::ErrorKind::InvalidData, cause)),
147 );
148
149 tokio::io::copy(&mut reader, file).await
150}