1use bitstream::{BitReader, BitWriter};
4use schema::schema_hash;
5use wire::{PacketFlags, PacketHeader, SectionTag, WirePacket, WireSection};
6
7use crate::error::{CodecError, CodecResult};
8use crate::limits::CodecLimits;
9use crate::snapshot::write_section;
10use crate::types::SnapshotTick;
11
/// Compact-header mode carried in a session-init packet.
///
/// The enum discriminant is the raw byte written into the session-init
/// section body (`compact_mode as u8`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CompactHeaderMode {
    /// The only mode currently defined; represented on the wire as `1`.
    SessionV1 = 1,
}

impl CompactHeaderMode {
    /// Converts a raw wire byte back into a mode.
    ///
    /// Returns `None` for every byte that does not correspond to a known
    /// variant.
    fn from_raw(raw: u8) -> Option<Self> {
        // Take the accepted byte from the discriminant itself so the encode
        // (`as u8`) and decode paths cannot drift apart.
        const SESSION_V1: u8 = CompactHeaderMode::SessionV1 as u8;

        if raw == SESSION_V1 {
            Some(Self::SessionV1)
        } else {
            None
        }
    }
}
27
28#[derive(Debug, Clone)]
30pub struct SessionState {
31 pub schema_hash: u64,
32 pub session_id: Option<u64>,
33 pub last_tick: SnapshotTick,
34 pub compact_mode: CompactHeaderMode,
35}
36
37pub fn encode_session_init_packet(
39 schema: &schema::Schema,
40 tick: SnapshotTick,
41 session_id: Option<u64>,
42 compact_mode: CompactHeaderMode,
43 limits: &CodecLimits,
44 out: &mut [u8],
45) -> CodecResult<usize> {
46 let mut offset = wire::HEADER_SIZE;
47 let body_len = write_section(
48 SectionTag::SessionInit,
49 &mut out[offset..],
50 limits,
51 |writer| encode_session_init_body(session_id, compact_mode, writer),
52 )?;
53 offset += body_len;
54
55 let payload_len = offset - wire::HEADER_SIZE;
56 let header = PacketHeader {
57 version: wire::VERSION,
58 flags: PacketFlags::session_init(),
59 schema_hash: schema_hash(schema),
60 tick: tick.raw(),
61 baseline_tick: 0,
62 payload_len: payload_len as u32,
63 };
64 wire::encode_header(&header, &mut out[..wire::HEADER_SIZE]).map_err(|_| {
65 CodecError::OutputTooSmall {
66 needed: wire::HEADER_SIZE,
67 available: out.len(),
68 }
69 })?;
70
71 Ok(offset)
72}
73
74fn encode_session_init_body(
75 session_id: Option<u64>,
76 compact_mode: CompactHeaderMode,
77 writer: &mut BitWriter<'_>,
78) -> CodecResult<()> {
79 writer.align_to_byte()?;
80 writer.write_u64_aligned(session_id.unwrap_or(0))?;
81 writer.write_u8_aligned(compact_mode as u8)?;
82 writer.align_to_byte()?;
83 Ok(())
84}
85
86pub fn decode_session_init_packet(
88 schema: &schema::Schema,
89 packet: &WirePacket<'_>,
90 limits: &CodecLimits,
91) -> CodecResult<SessionState> {
92 let header = packet.header;
93 if !header.flags.is_session_init() {
94 return Err(CodecError::SessionMissing);
95 }
96 if header.flags.is_full_snapshot() || header.flags.is_delta_snapshot() {
97 return Err(CodecError::SessionInitInvalid);
98 }
99 if header.baseline_tick != 0 {
100 return Err(CodecError::SessionInitInvalid);
101 }
102 let expected_hash = schema_hash(schema);
103 if header.schema_hash != expected_hash {
104 return Err(CodecError::SchemaMismatch {
105 expected: expected_hash,
106 found: header.schema_hash,
107 });
108 }
109
110 let mut init_section: Option<&WireSection<'_>> = None;
111 for section in &packet.sections {
112 match section.tag {
113 SectionTag::SessionInit => {
114 if init_section.is_some() {
115 return Err(CodecError::SessionInitInvalid);
116 }
117 init_section = Some(section);
118 }
119 _ => {
120 return Err(CodecError::UnexpectedSection {
121 section: section.tag,
122 });
123 }
124 }
125 }
126 let section = init_section.ok_or(CodecError::SessionInitInvalid)?;
127 let (session_id, compact_mode) = decode_session_init_body(section.body, limits)?;
128
129 Ok(SessionState {
130 schema_hash: header.schema_hash,
131 session_id,
132 last_tick: SnapshotTick::new(header.tick),
133 compact_mode,
134 })
135}
136
137fn decode_session_init_body(
138 body: &[u8],
139 limits: &CodecLimits,
140) -> CodecResult<(Option<u64>, CompactHeaderMode)> {
141 if body.len() > limits.max_section_bytes {
142 return Err(CodecError::LimitsExceeded {
143 kind: crate::error::LimitKind::SectionBytes,
144 limit: limits.max_section_bytes,
145 actual: body.len(),
146 });
147 }
148 let mut reader = BitReader::new(body);
149 reader.align_to_byte()?;
150 let session_id = reader.read_u64_aligned()?;
151 let mode = reader.read_u8_aligned()?;
152 reader.align_to_byte()?;
153 if reader.bits_remaining() != 0 {
154 return Err(CodecError::TrailingSectionData {
155 section: SectionTag::SessionInit,
156 remaining_bits: reader.bits_remaining(),
157 });
158 }
159 let compact_mode =
160 CompactHeaderMode::from_raw(mode).ok_or(CodecError::SessionUnsupportedMode { mode })?;
161 Ok((
162 if session_id == 0 {
163 None
164 } else {
165 Some(session_id)
166 },
167 compact_mode,
168 ))
169}
170
171pub fn decode_session_packet<'a>(
173 schema: &schema::Schema,
174 session: &mut SessionState,
175 bytes: &'a [u8],
176 wire_limits: &wire::Limits,
177) -> CodecResult<WirePacket<'a>> {
178 if session.schema_hash != schema_hash(schema) {
179 return Err(CodecError::SchemaMismatch {
180 expected: schema_hash(schema),
181 found: session.schema_hash,
182 });
183 }
184 let header =
185 wire::decode_session_header(bytes, session.last_tick.raw()).map_err(CodecError::Wire)?;
186 if header.tick <= session.last_tick.raw() {
187 return Err(CodecError::SessionOutOfOrder {
188 previous: session.last_tick.raw(),
189 current: header.tick,
190 });
191 }
192
193 let payload_start = header.header_len;
194 let payload_end = payload_start + header.payload_len as usize;
195 if payload_end > bytes.len() {
196 return Err(CodecError::Wire(wire::DecodeError::PayloadLengthMismatch {
197 header_len: header.payload_len,
198 actual_len: bytes.len().saturating_sub(payload_start),
199 }));
200 }
201 let payload = &bytes[payload_start..payload_end];
202 let sections = wire::decode_sections(payload, wire_limits).map_err(CodecError::Wire)?;
203
204 session.last_tick = SnapshotTick::new(header.tick);
205 let flags = if header.flags.is_full_snapshot() {
206 PacketFlags::full_snapshot()
207 } else {
208 PacketFlags::delta_snapshot()
209 };
210 Ok(WirePacket {
211 header: PacketHeader {
212 version: wire::VERSION,
213 flags,
214 schema_hash: session.schema_hash,
215 tick: header.tick,
216 baseline_tick: header.baseline_tick,
217 payload_len: header.payload_len,
218 },
219 sections,
220 })
221}
222
#[cfg(test)]
mod tests {
    use super::*;
    use crate::snapshot::{ComponentSnapshot, EntitySnapshot, FieldValue, Snapshot};
    use crate::types::EntityId;
    use schema::{ComponentDef, FieldCodec, FieldDef, FieldId, Schema};

    /// Builds a minimal schema: one component with a single boolean field.
    fn schema_one_bool() -> Schema {
        let component = ComponentDef::new(schema::ComponentId::new(1).unwrap())
            .field(FieldDef::new(FieldId::new(1).unwrap(), FieldCodec::bool()));
        Schema::new(vec![component]).unwrap()
    }

    /// Builds a snapshot holding entity 1 with component 1 set to `value`.
    fn snapshot_one_bool(tick: u32, value: bool) -> Snapshot {
        Snapshot {
            tick: SnapshotTick::new(tick),
            entities: vec![EntitySnapshot {
                id: EntityId::new(1),
                components: vec![ComponentSnapshot {
                    id: schema::ComponentId::new(1).unwrap(),
                    fields: vec![FieldValue::Bool(value)],
                }],
            }],
        }
    }

    /// A session-init packet survives an encode → wire decode → session
    /// decode round trip.
    #[test]
    fn session_init_roundtrip() {
        let schema = schema_one_bool();
        let mut buf = [0u8; 128];
        let bytes = encode_session_init_packet(
            &schema,
            SnapshotTick::new(5),
            Some(42),
            CompactHeaderMode::SessionV1,
            &CodecLimits::for_testing(),
            &mut buf,
        )
        .unwrap();
        let packet = wire::decode_packet(&buf[..bytes], &wire::Limits::for_testing()).unwrap();
        let session =
            decode_session_init_packet(&schema, &packet, &CodecLimits::for_testing()).unwrap();
        assert_eq!(session.session_id, Some(42));
        assert_eq!(session.last_tick.raw(), 5);
        assert_eq!(session.compact_mode, CompactHeaderMode::SessionV1);
        assert_eq!(session.schema_hash, schema_hash(&schema));
    }

    /// A compact delta packet decodes through the session path and is
    /// reported with delta-snapshot flags.
    #[test]
    fn session_decode_compact_packet() {
        let schema = schema_one_bool();
        let baseline = snapshot_one_bool(10, false);
        let current = snapshot_one_bool(11, true);
        let mut session = SessionState {
            schema_hash: schema_hash(&schema),
            session_id: Some(1),
            last_tick: baseline.tick,
            compact_mode: CompactHeaderMode::SessionV1,
        };
        let mut buf = [0u8; 256];
        // NOTE(review): the encoder is handed the decoder's `session.last_tick`.
        // If the encoder advances it, the decoder below starts from an already
        // advanced tick — confirm whether separate encoder/decoder state is
        // intended here.
        let bytes = crate::delta::encode_delta_snapshot_for_client_session_with_scratch(
            &schema,
            current.tick,
            baseline.tick,
            &baseline,
            &current,
            &CodecLimits::for_testing(),
            &mut crate::scratch::CodecScratch::default(),
            &mut session.last_tick,
            &mut buf,
        )
        .unwrap();
        let packet = decode_session_packet(
            &schema,
            &mut session,
            &buf[..bytes],
            &wire::Limits::for_testing(),
        )
        .unwrap();
        assert!(packet.header.flags.is_delta_snapshot());
        // The decoder stores the decoded tick in the session on success.
        assert_eq!(session.last_tick.raw(), packet.header.tick);
        assert_eq!(packet.header.schema_hash, session.schema_hash);
    }
}