fory_core/resolver/ref_resolver.rs
1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements. See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership. The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License. You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied. See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use crate::buffer::{Reader, Writer};
19use crate::error::Error;
20use num_enum::TryFromPrimitive;
21use std::any::Any;
22use std::collections::HashMap;
23use std::rc::Rc;
24use std::sync::Arc;
25
26#[derive(Debug, TryFromPrimitive)]
27#[repr(i8)]
28pub enum RefFlag {
29 Null = -3,
30 // Ref indicates that object is a not-null value.
31 // We don't use another byte to indicate REF, so that we can save one byte.
32 Ref = -2,
33 // NotNullValueFlag indicates that the object is a non-null value.
34 NotNullValue = -1,
35 // RefValueFlag indicates that the object is a referencable and first read.
36 RefValue = 0,
37}
38
/// Selects which null/reference flags surround a value during serialization.
///
/// Nullability and reference tracking are folded into a single parameter so that
/// both can be chosen per type and per field:
/// - `None`: not nullable, no ref tracking (primitives)
/// - `NullOnly`: nullable, no circular ref tracking
/// - `Tracking`: nullable, with circular ref tracking (Rc/Arc/Weak)
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
#[repr(u8)]
pub enum RefMode {
    /// No ref handling at all: no ref/null flag is written or read.
    /// For non-nullable primitives, or when the caller deals with refs itself.
    #[default]
    None = 0,

    /// Null check only, without reference tracking.
    /// Write: NullFlag (-3) for None, NotNullValueFlag (-1) for Some.
    /// Read: reads the flag and returns ForyDefault on null.
    NullOnly = 1,

    /// Full reference tracking, including circular references.
    /// Write: goes through RefWriter, which emits NullFlag, RefFlag+refId, or RefValueFlag.
    /// Read: goes through RefReader with full reference resolution.
    Tracking = 2,
}

impl RefMode {
    /// Builds the mode that corresponds to a `nullable` / `track_ref` pair.
    ///
    /// `track_ref` wins over `nullable`: whenever it is set the result is
    /// `Tracking`, regardless of `nullable`.
    #[inline]
    pub const fn from_flags(nullable: bool, track_ref: bool) -> Self {
        if track_ref {
            RefMode::Tracking
        } else if nullable {
            RefMode::NullOnly
        } else {
            RefMode::None
        }
    }

    /// Returns `true` when this mode reads/writes a ref flag (every mode but `None`).
    #[inline]
    pub const fn has_ref_flag(self) -> bool {
        match self {
            RefMode::None => false,
            RefMode::NullOnly | RefMode::Tracking => true,
        }
    }

    /// Returns `true` only for `Tracking`, the mode that tracks circular references.
    #[inline]
    pub const fn tracks_refs(self) -> bool {
        match self {
            RefMode::Tracking => true,
            RefMode::None | RefMode::NullOnly => false,
        }
    }

    /// Returns `true` when this mode handles nullable values (every mode but `None`).
    #[inline]
    pub const fn is_nullable(self) -> bool {
        self as u8 != RefMode::None as u8
    }
}
94
/// Write-side bookkeeping for shared references.
///
/// A `RefWriter` remembers the pointer address of every reference-counted object it
/// has been shown and the reference ID it handed out for it. When the same address
/// shows up again, a short back-reference is emitted instead of the object's content,
/// which is what makes shared and circular references serializable.
///
/// # Examples
///
/// ```rust
/// use fory_core::buffer::Writer;
/// use fory_core::resolver::RefWriter;
/// use std::rc::Rc;
///
/// let mut ref_writer = RefWriter::new();
/// let mut buffer = vec![];
/// let mut writer = Writer::from_buffer(&mut buffer);
/// let rc = Rc::new(42);
///
/// // First encounter - returns false, indicating object should be serialized
/// assert!(!ref_writer.try_write_rc_ref(&mut writer, &rc));
///
/// // Second encounter - returns true, indicating reference was written
/// let rc2 = rc.clone();
/// assert!(ref_writer.try_write_rc_ref(&mut writer, &rc2));
/// ```
#[derive(Default)]
pub struct RefWriter {
    /// Pointer address of each object seen so far, mapped to its reference ID.
    refs: HashMap<usize, u32>,
    /// The reference ID that will be handed out next.
    next_ref_id: u32,
}
128
/// A boxed, run-once closure that is handed a shared reference to the `RefReader`.
///
/// `RefReader` queues these in its `callbacks` list and invokes each one from
/// `RefReader::resolve_callbacks`.
type UpdateCallback = Box<dyn FnOnce(&RefReader)>;
130
131impl RefWriter {
132 /// Creates a new RefWriter instance.
133 pub fn new() -> Self {
134 Self::default()
135 }
136
137 /// Attempt to write a reference for an `Rc<T>`.
138 ///
139 /// Returns true if a reference was written (indicating this object has been
140 /// seen before), false if this is the first occurrence and the object should
141 /// be serialized normally.
142 ///
143 /// # Arguments
144 ///
145 /// * `writer` - The writer to write reference information to
146 /// * `rc` - The Rc to check for reference tracking
147 ///
148 /// # Returns
149 ///
150 /// * `true` if a reference was written
151 /// * `false` if this is the first occurrence of the object
152 #[inline]
153 pub fn try_write_rc_ref<T: ?Sized>(&mut self, writer: &mut Writer, rc: &Rc<T>) -> bool {
154 let ptr_addr = Rc::as_ptr(rc) as *const () as usize;
155
156 if let Some(&ref_id) = self.refs.get(&ptr_addr) {
157 writer.write_i8(RefFlag::Ref as i8);
158 writer.write_var_u32(ref_id);
159 true
160 } else {
161 let ref_id = self.next_ref_id;
162 self.next_ref_id += 1;
163 self.refs.insert(ptr_addr, ref_id);
164 writer.write_i8(RefFlag::RefValue as i8);
165 false
166 }
167 }
168
169 /// Attempt to write a reference for an `Arc<T>`.
170 ///
171 /// Returns true if a reference was written (indicating this object has been
172 /// seen before), false if this is the first occurrence and the object should
173 /// be serialized normally.
174 ///
175 /// # Arguments
176 ///
177 /// * `writer` - The writer to write reference information to
178 /// * `arc` - The Arc to check for reference tracking
179 ///
180 /// # Returns
181 ///
182 /// * `true` if a reference was written
183 /// * `false` if this is the first occurrence of the object
184 #[inline]
185 pub fn try_write_arc_ref<T: ?Sized>(&mut self, writer: &mut Writer, arc: &Arc<T>) -> bool {
186 let ptr_addr = Arc::as_ptr(arc) as *const () as usize;
187
188 if let Some(&ref_id) = self.refs.get(&ptr_addr) {
189 // This object has been seen before, write a reference
190 writer.write_i8(RefFlag::Ref as i8);
191 writer.write_var_u32(ref_id);
192 true
193 } else {
194 // First time seeing this object, register it and return false
195 let ref_id = self.next_ref_id;
196 self.next_ref_id += 1;
197 self.refs.insert(ptr_addr, ref_id);
198 writer.write_i8(RefFlag::RefValue as i8);
199 false
200 }
201 }
202
203 /// Reserve a reference ID slot without storing anything.
204 ///
205 /// This is used for xlang compatibility where ALL objects (including struct values,
206 /// not just Rc/Arc) participate in reference tracking.
207 ///
208 /// # Returns
209 ///
210 /// The reserved reference ID
211 #[inline(always)]
212 pub fn reserve_ref_id(&mut self) -> u32 {
213 let ref_id = self.next_ref_id;
214 self.next_ref_id += 1;
215 ref_id
216 }
217
218 /// Clear all stored references.
219 ///
220 /// This is useful for reusing the RefWriter for multiple serialization operations.
221 #[inline(always)]
222 pub fn reset(&mut self) {
223 self.refs.clear();
224 self.next_ref_id = 0;
225 }
226}
227
228/// Reference reader for resolving shared references during deserialization.
229///
230/// RefReader maintains a vector of previously deserialized objects that can be
231/// referenced by ID. When a reference is encountered during deserialization,
232/// the RefReader can return the previously deserialized object instead of
233/// deserializing it again.
234///
235/// # Examples
236///
237/// ```rust
238/// use fory_core::resolver::RefReader;
239/// use std::rc::Rc;
240///
241/// let mut ref_reader = RefReader::new();
242/// let rc = Rc::new(42);
243///
244/// // Store an object for later reference
245/// let ref_id = ref_reader.store_rc_ref(rc.clone());
246///
247/// // Retrieve the object by reference ID
248/// let retrieved = ref_reader.get_rc_ref::<i32>(ref_id).unwrap();
249/// assert!(Rc::ptr_eq(&rc, &retrieved));
250/// ```
251#[derive(Default)]
252pub struct RefReader {
253 /// Vector to store boxed objects for reference resolution
254 refs: Vec<Box<dyn Any>>,
255 /// Callbacks to execute when references are resolved
256 callbacks: Vec<UpdateCallback>,
257}
258
// Dangerous, but needed so a `RefReader` can be used from multi-threaded code.
//
// SAFETY: NOTE(review): nothing in this file proves these impls sound. `RefReader`
// stores `Box<dyn Any>` values (which may wrap an `Rc`) and boxed `FnOnce` callbacks,
// none of which are `Send`/`Sync` by themselves. Presumably callers never let a
// populated reader be touched by more than one thread at a time; confirm before
// relying on this.
unsafe impl Send for RefReader {}
unsafe impl Sync for RefReader {}
262
263impl RefReader {
264 /// Creates a new RefReader instance.
265 pub fn new() -> Self {
266 Self::default()
267 }
268
269 /// Reserve a reference ID slot without storing anything yet.
270 ///
271 /// Returns the reserved reference ID that will be used when storing the object later.
272 #[inline(always)]
273 pub fn reserve_ref_id(&mut self) -> u32 {
274 let ref_id = self.refs.len() as u32;
275 self.refs.push(Box::new(()));
276 ref_id
277 }
278
279 /// Store an `Rc<T>` at a previously reserved reference ID.
280 ///
281 /// # Arguments
282 ///
283 /// * `ref_id` - The reference ID that was reserved
284 /// * `rc` - The Rc to store
285 #[inline(always)]
286 pub fn store_rc_ref_at<T: 'static + ?Sized>(&mut self, ref_id: u32, rc: Rc<T>) {
287 self.refs[ref_id as usize] = Box::new(rc);
288 }
289
290 /// Store an `Rc<T>` for later reference resolution during deserialization.
291 ///
292 /// # Arguments
293 ///
294 /// * `rc` - The Rc to store for later reference
295 ///
296 /// # Returns
297 ///
298 /// The reference ID that can be used to retrieve this object later
299 #[inline(always)]
300 pub fn store_rc_ref<T: 'static + ?Sized>(&mut self, rc: Rc<T>) -> u32 {
301 let ref_id = self.refs.len() as u32;
302 self.refs.push(Box::new(rc));
303 ref_id
304 }
305
306 /// Store an `Arc<T>` at a previously reserved reference ID.
307 ///
308 /// # Arguments
309 ///
310 /// * `ref_id` - The reference ID that was reserved
311 /// * `arc` - The Arc to store
312 pub fn store_arc_ref_at<T: 'static + ?Sized>(&mut self, ref_id: u32, arc: Arc<T>) {
313 self.refs[ref_id as usize] = Box::new(arc);
314 }
315
316 /// Store an `Arc<T>` for later reference resolution during deserialization.
317 ///
318 /// # Arguments
319 ///
320 /// * `arc` - The Arc to store for later reference
321 ///
322 /// # Returns
323 ///
324 /// The reference ID that can be used to retrieve this object later
325 #[inline(always)]
326 pub fn store_arc_ref<T: 'static + ?Sized>(&mut self, arc: Arc<T>) -> u32 {
327 let ref_id = self.refs.len() as u32;
328 self.refs.push(Box::new(arc));
329 ref_id
330 }
331
332 /// Get an `Rc<T>` by reference ID during deserialization.
333 ///
334 /// # Arguments
335 ///
336 /// * `ref_id` - The reference ID returned by `store_rc_ref`
337 ///
338 /// # Returns
339 ///
340 /// * `Some(Rc<T>)` if the reference ID is valid and the type matches
341 /// * `None` if the reference ID is invalid or the type doesn't match
342 #[inline(always)]
343 pub fn get_rc_ref<T: 'static + ?Sized>(&self, ref_id: u32) -> Option<Rc<T>> {
344 let any_box = self.refs.get(ref_id as usize)?;
345 any_box.downcast_ref::<Rc<T>>().cloned()
346 }
347
348 /// Get an `Arc<T>` by reference ID during deserialization.
349 ///
350 /// # Arguments
351 ///
352 /// * `ref_id` - The reference ID returned by `store_arc_ref`
353 ///
354 /// # Returns
355 ///
356 /// * `Some(Arc<T>)` if the reference ID is valid and the type matches
357 /// * `None` if the reference ID is invalid or the type doesn't match
358 #[inline(always)]
359 pub fn get_arc_ref<T: 'static + ?Sized>(&self, ref_id: u32) -> Option<Arc<T>> {
360 let any_box = self.refs.get(ref_id as usize)?;
361 any_box.downcast_ref::<Arc<T>>().cloned()
362 }
363
364 /// Add a callback to be executed when weak references are resolved.
365 ///
366 /// # Arguments
367 ///
368 /// * `callback` - A closure that takes a reference to the RefReader
369 #[inline(always)]
370 pub fn add_callback(&mut self, callback: UpdateCallback) {
371 self.callbacks.push(callback);
372 }
373
374 /// Read a reference flag and determine what action to take.
375 ///
376 /// # Arguments
377 ///
378 /// * `reader` - The reader to read the reference flag from
379 ///
380 /// # Returns
381 ///
382 /// The RefFlag indicating what type of reference this is
383 ///
384 /// # Errors
385 ///
386 /// Errors if an invalid reference flag value is encountered
387 #[inline(always)]
388 pub fn read_ref_flag(&self, reader: &mut Reader) -> Result<RefFlag, Error> {
389 let flag_value = reader.read_i8()?;
390 Ok(match flag_value {
391 -3 => RefFlag::Null,
392 -2 => RefFlag::Ref,
393 -1 => RefFlag::NotNullValue,
394 0 => RefFlag::RefValue,
395 _ => Err(Error::invalid_ref(format!(
396 "Invalid reference flag: {}",
397 flag_value
398 )))?,
399 })
400 }
401
402 /// Read a reference ID from the reader.
403 ///
404 /// # Arguments
405 ///
406 /// * `reader` - The reader to read the reference ID from
407 ///
408 /// # Returns
409 ///
410 /// The reference ID as a u32
411 #[inline(always)]
412 pub fn read_ref_id(&self, reader: &mut Reader) -> Result<u32, Error> {
413 reader.read_var_u32()
414 }
415
416 /// Execute all pending callbacks to resolve weak pointer references.
417 ///
418 /// This should be called after deserialization completes to update any weak pointers
419 /// that referenced objects which were not yet available during deserialization.
420 #[inline(always)]
421 pub fn resolve_callbacks(&mut self) {
422 let callbacks = std::mem::take(&mut self.callbacks);
423 for callback in callbacks {
424 callback(self);
425 }
426 }
427
428 /// Clear all stored references and callbacks.
429 ///
430 /// This is useful for reusing the RefReader for multiple deserialization operations.
431 #[inline(always)]
432 pub fn reset(&mut self) {
433 self.resolve_callbacks();
434 self.refs.clear();
435 self.callbacks.clear();
436 }
437}