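//! Conversions between core domain types (distances, embeddings, peer ids,
//! peers and address books) and their protobuf message representations.
//!
//! Source path: `bytesandbrains_core/proto/conversions.rs`.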
1use super::{DistanceProto, PeerProto, ProtoConversionError, TensorProto, DATA_TYPE_FLOAT};
2use crate::{
3    address::{Address, AddressBook},
4    embedding::{F32Distance, F32Embedding, Embedding},
5    peer::Peer,
6    peer_id::PeerId,
7};
8
9impl From<F32Distance> for DistanceProto {
10    fn from(dist: F32Distance) -> Self {
11        DistanceProto {
12            data_type: DATA_TYPE_FLOAT,
13            float_data: vec![dist.0],
14            int32_data: vec![],
15            string_data: vec![],
16            int64_data: vec![],
17        }
18    }
19}
20
21impl TryFrom<DistanceProto> for F32Distance {
22    type Error = ProtoConversionError;
23
24    fn try_from(proto: DistanceProto) -> Result<Self, ProtoConversionError> {
25        if proto.data_type != DATA_TYPE_FLOAT {
26            return Err(ProtoConversionError::InvalidDataType {
27                expected: DATA_TYPE_FLOAT,
28                actual: proto.data_type,
29            });
30        }
31        if proto.float_data.len() != 1 {
32            return Err(ProtoConversionError::ConversionFailed(
33                format!("Expected 1 float in DistanceProto, got {}", proto.float_data.len()),
34            ));
35        }
36        Ok(F32Distance(proto.float_data[0]))
37    }
38}
39
40impl<const L: usize> From<F32Embedding<L>> for TensorProto {
41    fn from(embedding: F32Embedding<L>) -> Self {
42        TensorProto {
43            dims: vec![L as i64],
44            data_type: DATA_TYPE_FLOAT,
45            float_data: embedding.0.to_vec(),
46            ..Default::default()
47        }
48    }
49}
50
51impl<const L: usize> TryFrom<TensorProto> for F32Embedding<L> {
52    type Error = ProtoConversionError;
53
54    fn try_from(proto: TensorProto) -> Result<Self, ProtoConversionError> {
55        if proto.data_type != DATA_TYPE_FLOAT {
56            return Err(ProtoConversionError::InvalidDataType {
57                expected: DATA_TYPE_FLOAT,
58                actual: proto.data_type,
59            });
60        }
61
62        let expected_dims = vec![L as i64];
63        if proto.dims != expected_dims {
64            return Err(ProtoConversionError::InvalidTensorShape {
65                expected: expected_dims,
66                actual: proto.dims,
67            });
68        }
69
70        if proto.float_data.len() != L {
71            return Err(ProtoConversionError::ConversionFailed(
72                format!("Expected {} floats in TensorProto, got {}", L, proto.float_data.len()),
73            ));
74        }
75
76        Ok(F32Embedding::from_slice(&proto.float_data))
77    }
78}
79
80impl From<PeerId> for Vec<u8> {
81    fn from(peer_id: PeerId) -> Self {
82        peer_id.to_bytes()
83    }
84}
85
86impl<A: Address> From<&Peer<A>> for PeerProto {
87    fn from(peer: &Peer<A>) -> Self {
88        PeerProto {
89            peer_id: peer.peer_id.to_bytes(),
90            addresses: peer.addresses.iter().map(|a| a.to_string()).collect(),
91        }
92    }
93}
94
95impl<A: Address> From<Peer<A>> for PeerProto {
96    fn from(peer: Peer<A>) -> Self {
97        PeerProto::from(&peer)
98    }
99}
100
101impl<A: Address> TryFrom<PeerProto> for Peer<A> {
102    type Error = ProtoConversionError;
103
104    fn try_from(proto: PeerProto) -> Result<Self, ProtoConversionError> {
105        let peer_id = PeerId::from_slice(&proto.peer_id);
106        let max_size = proto.addresses.len();
107        let addresses = addresses_from_proto(proto.addresses, max_size)?;
108        Ok(Peer { peer_id, addresses })
109    }
110}
111
112pub fn addresses_to_proto<A: Address>(addresses: &AddressBook<A>) -> Vec<String> {
113    addresses.iter().map(|a| a.to_string()).collect()
114}
115
116pub fn addresses_from_proto<A: Address>(
117    addresses: Vec<String>,
118    max_size: usize,
119) -> Result<AddressBook<A>, ProtoConversionError> {
120    if addresses.is_empty() {
121        return Err(ProtoConversionError::ConversionFailed(
122            "Empty addresses".to_string(),
123        ));
124    }
125
126    let first: A = addresses[0].parse().map_err(|_| {
127        ProtoConversionError::ConversionFailed(
128            format!("Failed to parse address: {}", addresses[0]),
129        )
130    })?;
131
132    let mut book = AddressBook::new(first, max_size);
133
134    for addr_str in addresses.into_iter().skip(1) {
135        let addr: A = addr_str.parse().map_err(|_| {
136            ProtoConversionError::ConversionFailed(
137                format!("Failed to parse address: {}", addr_str),
138            )
139        })?;
140        book.seen(addr);
141    }
142
143    Ok(book)
144}
145
#[cfg(test)]
mod tests {
    // Round-trip tests for every conversion in this module, plus tests that
    // each validation branch rejects malformed input with the expected
    // error variant.
    use super::*;

    /// Builds a float `DistanceProto` with the given payload and empty
    /// non-float payload vectors.
    fn float_distance_proto(float_data: Vec<f32>) -> DistanceProto {
        DistanceProto {
            data_type: DATA_TYPE_FLOAT,
            float_data,
            int32_data: vec![],
            string_data: vec![],
            int64_data: vec![],
        }
    }

    #[test]
    fn test_f32_distance_proto_roundtrip() {
        let dist = F32Distance(42.5);
        let proto: DistanceProto = dist.into();
        let recovered: F32Distance = proto.try_into().unwrap();
        assert_eq!(dist, recovered);
    }

    #[test]
    fn test_f32_distance_rejects_wrong_data_type() {
        let proto = DistanceProto {
            data_type: DATA_TYPE_FLOAT + 1,
            ..float_distance_proto(vec![1.0])
        };
        let result: Result<F32Distance, _> = proto.try_into();
        assert!(matches!(
            result,
            Err(ProtoConversionError::InvalidDataType { .. })
        ));
    }

    #[test]
    fn test_f32_distance_rejects_wrong_value_count() {
        let empty: Result<F32Distance, _> = float_distance_proto(vec![]).try_into();
        assert!(matches!(
            empty,
            Err(ProtoConversionError::ConversionFailed(_))
        ));

        let two: Result<F32Distance, _> = float_distance_proto(vec![1.0, 2.0]).try_into();
        assert!(matches!(two, Err(ProtoConversionError::ConversionFailed(_))));
    }

    #[test]
    fn test_f32_embedding_proto_roundtrip() {
        let embedding = F32Embedding::<4>([1.0, 2.0, 3.0, 4.0]);
        let proto: TensorProto = embedding.clone().into();
        let recovered: F32Embedding<4> = proto.try_into().unwrap();
        assert_eq!(embedding, recovered);
    }

    #[test]
    fn test_f32_embedding_rejects_wrong_data_type() {
        let mut proto: TensorProto = F32Embedding::<4>([1.0, 2.0, 3.0, 4.0]).into();
        proto.data_type = DATA_TYPE_FLOAT + 1;
        let result: Result<F32Embedding<4>, _> = proto.try_into();
        assert!(matches!(
            result,
            Err(ProtoConversionError::InvalidDataType { .. })
        ));
    }

    #[test]
    fn test_f32_embedding_rejects_wrong_shape() {
        let proto: TensorProto = F32Embedding::<3>([1.0, 2.0, 3.0]).into();
        let result: Result<F32Embedding<4>, _> = proto.try_into();
        assert!(matches!(
            result,
            Err(ProtoConversionError::InvalidTensorShape { .. })
        ));
    }

    #[test]
    fn test_f32_embedding_rejects_mismatched_float_count() {
        // Shape claims [4] but the payload only carries 3 values.
        let mut proto: TensorProto = F32Embedding::<4>([1.0, 2.0, 3.0, 4.0]).into();
        proto.float_data.pop();
        let result: Result<F32Embedding<4>, _> = proto.try_into();
        assert!(matches!(
            result,
            Err(ProtoConversionError::ConversionFailed(_))
        ));
    }

    #[test]
    fn test_peer_proto_roundtrip() {
        let peer_id = PeerId::from_data("test-peer");
        let addresses = AddressBook::new("192.168.1.1:8080".to_string(), 5);
        let peer = Peer::new(peer_id, addresses);

        let proto: PeerProto = peer.clone().into();
        let recovered: Peer<String> = proto.try_into().unwrap();

        assert_eq!(peer.peer_id, recovered.peer_id);
        assert_eq!(peer.addresses.first(), recovered.addresses.first());
    }

    #[test]
    fn test_peer_proto_multi_address_roundtrip() {
        let peer_id = PeerId::from_data("test-peer");
        let mut addresses = AddressBook::new("addr1".to_string(), 5);
        addresses.seen("addr2".to_string());
        addresses.seen("addr3".to_string());
        let peer = Peer::new(peer_id, addresses);

        let proto: PeerProto = peer.clone().into();
        assert_eq!(proto.addresses.len(), 3);
        let mut sent = proto.addresses.clone();
        sent.sort();

        let recovered: Peer<String> = proto.try_into().unwrap();
        assert_eq!(recovered.addresses.len(), 3);

        // Compare as sorted sets: AddressBook may reorder entries on `seen`,
        // but no address may be lost or altered by the round trip.
        let mut received = PeerProto::from(&recovered).addresses;
        received.sort();
        assert_eq!(sent, received);
    }

    #[test]
    fn test_peer_proto_empty_addresses_fails() {
        let proto = PeerProto {
            peer_id: vec![0u8; 64],
            addresses: vec![],
        };
        let result: Result<Peer<String>, _> = proto.try_into();
        assert!(matches!(
            result,
            Err(ProtoConversionError::ConversionFailed(_))
        ));
    }
}