Skip to main content

fory_core/resolver/
ref_resolver.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use crate::buffer::{Reader, Writer};
19use crate::error::Error;
20use num_enum::TryFromPrimitive;
21use std::any::Any;
22use std::collections::HashMap;
23use std::rc::Rc;
24use std::sync::Arc;
25
26#[derive(Debug, TryFromPrimitive)]
27#[repr(i8)]
28pub enum RefFlag {
29    Null = -3,
30    // Ref indicates that object is a not-null value.
31    // We don't use another byte to indicate REF, so that we can save one byte.
32    Ref = -2,
33    // NotNullValueFlag indicates that the object is a non-null value.
34    NotNullValue = -1,
35    // RefValueFlag indicates that the object is a referencable and first read.
36    RefValue = 0,
37}
38
/// How null and reference flags are handled for a value during (de)serialization.
///
/// A single mode captures both whether a value may be null and whether shared or
/// circular references are tracked, so it can be chosen per type and per field:
///
/// | Mode       | Nullable | Reference tracking         |
/// |------------|----------|----------------------------|
/// | `None`     | no       | no (e.g. primitives)       |
/// | `NullOnly` | yes      | no                         |
/// | `Tracking` | yes      | yes (`Rc`/`Arc`/`Weak`)    |
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[repr(u8)]
pub enum RefMode {
    /// No ref/null flag is written or read at all.
    ///
    /// Intended for non-nullable primitives, or for callers that handle the ref
    /// flag themselves.
    #[default]
    None = 0,

    /// Null handling only, without reference tracking.
    ///
    /// Writing emits NullFlag (-3) for `None` and NotNullValueFlag (-1) for `Some`;
    /// reading consumes the flag and yields `ForyDefault` when it is null.
    NullOnly = 1,

    /// Full reference tracking, including circular references.
    ///
    /// Writing goes through `RefWriter`, which emits NullFlag, RefFlag plus a
    /// reference ID, or RefValueFlag; reading resolves references via `RefReader`.
    Tracking = 2,
}

impl RefMode {
    /// Builds a `RefMode` from independent `nullable` and `track_ref` flags.
    ///
    /// Reference tracking wins: whenever `track_ref` is set the result is
    /// `Tracking`, regardless of `nullable`.
    #[inline]
    pub const fn from_flags(nullable: bool, track_ref: bool) -> Self {
        if track_ref {
            RefMode::Tracking
        } else if nullable {
            RefMode::NullOnly
        } else {
            RefMode::None
        }
    }

    /// Returns `true` when a ref/null flag byte is written and read in this mode.
    #[inline]
    pub const fn has_ref_flag(self) -> bool {
        matches!(self, RefMode::NullOnly | RefMode::Tracking)
    }

    /// Returns `true` when shared/circular references are tracked in this mode.
    #[inline]
    pub const fn tracks_refs(self) -> bool {
        matches!(self, RefMode::Tracking)
    }

    /// Returns `true` when values handled in this mode may be null.
    ///
    /// Every mode that carries a flag byte can encode null, so this coincides
    /// with [`RefMode::has_ref_flag`].
    #[inline]
    pub const fn is_nullable(self) -> bool {
        self.has_ref_flag()
    }
}
94
/// Tracks shared references while serializing.
///
/// Every distinct pointer address presented to a `RefWriter` receives a sequential
/// reference ID. When the same allocation shows up again, only a back-reference
/// (flag + ID) is written instead of serializing the object a second time, which is
/// what lets shared and circular object graphs round-trip.
///
/// # Examples
///
/// ```rust
/// use fory_core::buffer::Writer;
/// use fory_core::resolver::RefWriter;
/// use std::rc::Rc;
///
/// let mut ref_writer = RefWriter::new();
/// let mut buffer = vec![];
/// let mut writer = Writer::from_buffer(&mut buffer);
/// let rc = Rc::new(42);
///
/// // First sighting: nothing was written as a reference, so serialize the value.
/// assert!(!ref_writer.try_write_rc_ref(&mut writer, &rc));
///
/// // Same allocation again: a back-reference was written instead.
/// let rc2 = rc.clone();
/// assert!(ref_writer.try_write_rc_ref(&mut writer, &rc2));
/// ```
#[derive(Default)]
pub struct RefWriter {
    /// ID that will be handed to the next newly seen object (starts at 0).
    next_ref_id: u32,
    /// Pointer address -> reference ID for every object already registered.
    refs: HashMap<usize, u32>,
}
128
129type UpdateCallback = Box<dyn FnOnce(&RefReader)>;
130
131impl RefWriter {
132    /// Creates a new RefWriter instance.
133    pub fn new() -> Self {
134        Self::default()
135    }
136
137    /// Attempt to write a reference for an `Rc<T>`.
138    ///
139    /// Returns true if a reference was written (indicating this object has been
140    /// seen before), false if this is the first occurrence and the object should
141    /// be serialized normally.
142    ///
143    /// # Arguments
144    ///
145    /// * `writer` - The writer to write reference information to
146    /// * `rc` - The Rc to check for reference tracking
147    ///
148    /// # Returns
149    ///
150    /// * `true` if a reference was written
151    /// * `false` if this is the first occurrence of the object
152    #[inline]
153    pub fn try_write_rc_ref<T: ?Sized>(&mut self, writer: &mut Writer, rc: &Rc<T>) -> bool {
154        let ptr_addr = Rc::as_ptr(rc) as *const () as usize;
155
156        if let Some(&ref_id) = self.refs.get(&ptr_addr) {
157            writer.write_i8(RefFlag::Ref as i8);
158            writer.write_var_u32(ref_id);
159            true
160        } else {
161            let ref_id = self.next_ref_id;
162            self.next_ref_id += 1;
163            self.refs.insert(ptr_addr, ref_id);
164            writer.write_i8(RefFlag::RefValue as i8);
165            false
166        }
167    }
168
169    /// Attempt to write a reference for an `Arc<T>`.
170    ///
171    /// Returns true if a reference was written (indicating this object has been
172    /// seen before), false if this is the first occurrence and the object should
173    /// be serialized normally.
174    ///
175    /// # Arguments
176    ///
177    /// * `writer` - The writer to write reference information to
178    /// * `arc` - The Arc to check for reference tracking
179    ///
180    /// # Returns
181    ///
182    /// * `true` if a reference was written
183    /// * `false` if this is the first occurrence of the object
184    #[inline]
185    pub fn try_write_arc_ref<T: ?Sized>(&mut self, writer: &mut Writer, arc: &Arc<T>) -> bool {
186        let ptr_addr = Arc::as_ptr(arc) as *const () as usize;
187
188        if let Some(&ref_id) = self.refs.get(&ptr_addr) {
189            // This object has been seen before, write a reference
190            writer.write_i8(RefFlag::Ref as i8);
191            writer.write_var_u32(ref_id);
192            true
193        } else {
194            // First time seeing this object, register it and return false
195            let ref_id = self.next_ref_id;
196            self.next_ref_id += 1;
197            self.refs.insert(ptr_addr, ref_id);
198            writer.write_i8(RefFlag::RefValue as i8);
199            false
200        }
201    }
202
203    /// Reserve a reference ID slot without storing anything.
204    ///
205    /// This is used for xlang compatibility where ALL objects (including struct values,
206    /// not just Rc/Arc) participate in reference tracking.
207    ///
208    /// # Returns
209    ///
210    /// The reserved reference ID
211    #[inline(always)]
212    pub fn reserve_ref_id(&mut self) -> u32 {
213        let ref_id = self.next_ref_id;
214        self.next_ref_id += 1;
215        ref_id
216    }
217
218    /// Clear all stored references.
219    ///
220    /// This is useful for reusing the RefWriter for multiple serialization operations.
221    #[inline(always)]
222    pub fn reset(&mut self) {
223        self.refs.clear();
224        self.next_ref_id = 0;
225    }
226}
227
228/// Reference reader for resolving shared references during deserialization.
229///
230/// RefReader maintains a vector of previously deserialized objects that can be
231/// referenced by ID. When a reference is encountered during deserialization,
232/// the RefReader can return the previously deserialized object instead of
233/// deserializing it again.
234///
235/// # Examples
236///
237/// ```rust
238/// use fory_core::resolver::RefReader;
239/// use std::rc::Rc;
240///
241/// let mut ref_reader = RefReader::new();
242/// let rc = Rc::new(42);
243///
244/// // Store an object for later reference
245/// let ref_id = ref_reader.store_rc_ref(rc.clone());
246///
247/// // Retrieve the object by reference ID
248/// let retrieved = ref_reader.get_rc_ref::<i32>(ref_id).unwrap();
249/// assert!(Rc::ptr_eq(&rc, &retrieved));
250/// ```
251#[derive(Default)]
252pub struct RefReader {
253    /// Vector to store boxed objects for reference resolution
254    refs: Vec<Box<dyn Any>>,
255    /// Callbacks to execute when references are resolved
256    callbacks: Vec<UpdateCallback>,
257}
258
259// danger but useful for multi-thread
260unsafe impl Send for RefReader {}
261unsafe impl Sync for RefReader {}
262
263impl RefReader {
264    /// Creates a new RefReader instance.
265    pub fn new() -> Self {
266        Self::default()
267    }
268
269    /// Reserve a reference ID slot without storing anything yet.
270    ///
271    /// Returns the reserved reference ID that will be used when storing the object later.
272    #[inline(always)]
273    pub fn reserve_ref_id(&mut self) -> u32 {
274        let ref_id = self.refs.len() as u32;
275        self.refs.push(Box::new(()));
276        ref_id
277    }
278
279    /// Store an `Rc<T>` at a previously reserved reference ID.
280    ///
281    /// # Arguments
282    ///
283    /// * `ref_id` - The reference ID that was reserved
284    /// * `rc` - The Rc to store
285    #[inline(always)]
286    pub fn store_rc_ref_at<T: 'static + ?Sized>(&mut self, ref_id: u32, rc: Rc<T>) {
287        self.refs[ref_id as usize] = Box::new(rc);
288    }
289
290    /// Store an `Rc<T>` for later reference resolution during deserialization.
291    ///
292    /// # Arguments
293    ///
294    /// * `rc` - The Rc to store for later reference
295    ///
296    /// # Returns
297    ///
298    /// The reference ID that can be used to retrieve this object later
299    #[inline(always)]
300    pub fn store_rc_ref<T: 'static + ?Sized>(&mut self, rc: Rc<T>) -> u32 {
301        let ref_id = self.refs.len() as u32;
302        self.refs.push(Box::new(rc));
303        ref_id
304    }
305
306    /// Store an `Arc<T>` at a previously reserved reference ID.
307    ///
308    /// # Arguments
309    ///
310    /// * `ref_id` - The reference ID that was reserved
311    /// * `arc` - The Arc to store
312    pub fn store_arc_ref_at<T: 'static + ?Sized>(&mut self, ref_id: u32, arc: Arc<T>) {
313        self.refs[ref_id as usize] = Box::new(arc);
314    }
315
316    /// Store an `Arc<T>` for later reference resolution during deserialization.
317    ///
318    /// # Arguments
319    ///
320    /// * `arc` - The Arc to store for later reference
321    ///
322    /// # Returns
323    ///
324    /// The reference ID that can be used to retrieve this object later
325    #[inline(always)]
326    pub fn store_arc_ref<T: 'static + ?Sized>(&mut self, arc: Arc<T>) -> u32 {
327        let ref_id = self.refs.len() as u32;
328        self.refs.push(Box::new(arc));
329        ref_id
330    }
331
332    /// Get an `Rc<T>` by reference ID during deserialization.
333    ///
334    /// # Arguments
335    ///
336    /// * `ref_id` - The reference ID returned by `store_rc_ref`
337    ///
338    /// # Returns
339    ///
340    /// * `Some(Rc<T>)` if the reference ID is valid and the type matches
341    /// * `None` if the reference ID is invalid or the type doesn't match
342    #[inline(always)]
343    pub fn get_rc_ref<T: 'static + ?Sized>(&self, ref_id: u32) -> Option<Rc<T>> {
344        let any_box = self.refs.get(ref_id as usize)?;
345        any_box.downcast_ref::<Rc<T>>().cloned()
346    }
347
348    /// Get an `Arc<T>` by reference ID during deserialization.
349    ///
350    /// # Arguments
351    ///
352    /// * `ref_id` - The reference ID returned by `store_arc_ref`
353    ///
354    /// # Returns
355    ///
356    /// * `Some(Arc<T>)` if the reference ID is valid and the type matches
357    /// * `None` if the reference ID is invalid or the type doesn't match
358    #[inline(always)]
359    pub fn get_arc_ref<T: 'static + ?Sized>(&self, ref_id: u32) -> Option<Arc<T>> {
360        let any_box = self.refs.get(ref_id as usize)?;
361        any_box.downcast_ref::<Arc<T>>().cloned()
362    }
363
364    /// Add a callback to be executed when weak references are resolved.
365    ///
366    /// # Arguments
367    ///
368    /// * `callback` - A closure that takes a reference to the RefReader
369    #[inline(always)]
370    pub fn add_callback(&mut self, callback: UpdateCallback) {
371        self.callbacks.push(callback);
372    }
373
374    /// Read a reference flag and determine what action to take.
375    ///
376    /// # Arguments
377    ///
378    /// * `reader` - The reader to read the reference flag from
379    ///
380    /// # Returns
381    ///
382    /// The RefFlag indicating what type of reference this is
383    ///
384    /// # Errors
385    ///
386    /// Errors if an invalid reference flag value is encountered
387    #[inline(always)]
388    pub fn read_ref_flag(&self, reader: &mut Reader) -> Result<RefFlag, Error> {
389        let flag_value = reader.read_i8()?;
390        Ok(match flag_value {
391            -3 => RefFlag::Null,
392            -2 => RefFlag::Ref,
393            -1 => RefFlag::NotNullValue,
394            0 => RefFlag::RefValue,
395            _ => Err(Error::invalid_ref(format!(
396                "Invalid reference flag: {}",
397                flag_value
398            )))?,
399        })
400    }
401
402    /// Read a reference ID from the reader.
403    ///
404    /// # Arguments
405    ///
406    /// * `reader` - The reader to read the reference ID from
407    ///
408    /// # Returns
409    ///
410    /// The reference ID as a u32
411    #[inline(always)]
412    pub fn read_ref_id(&self, reader: &mut Reader) -> Result<u32, Error> {
413        reader.read_var_u32()
414    }
415
416    /// Execute all pending callbacks to resolve weak pointer references.
417    ///
418    /// This should be called after deserialization completes to update any weak pointers
419    /// that referenced objects which were not yet available during deserialization.
420    #[inline(always)]
421    pub fn resolve_callbacks(&mut self) {
422        let callbacks = std::mem::take(&mut self.callbacks);
423        for callback in callbacks {
424            callback(self);
425        }
426    }
427
428    /// Clear all stored references and callbacks.
429    ///
430    /// This is useful for reusing the RefReader for multiple deserialization operations.
431    #[inline(always)]
432    pub fn reset(&mut self) {
433        self.resolve_callbacks();
434        self.refs.clear();
435        self.callbacks.clear();
436    }
437}