flatbuffers_reflection/
safe_buffer.rs

/*
 * Copyright 2025 Google Inc. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
17use crate::r#struct::Struct;
18use crate::reflection_generated::reflection::{Field, Schema};
19use crate::reflection_verifier::verify_with_options;
20use crate::{
21    get_any_field_float, get_any_field_float_in_struct, get_any_field_integer,
22    get_any_field_integer_in_struct, get_any_field_string, get_any_field_string_in_struct,
23    get_any_root, get_field_float, get_field_integer, get_field_string, get_field_struct,
24    get_field_struct_in_struct, get_field_table, get_field_vector, FlatbufferError,
25    FlatbufferResult, ForwardsUOffset,
26};
27use flatbuffers::{Follow, Table, Vector, VerifierOptions};
28use num_traits::float::Float;
29use num_traits::int::PrimInt;
30use num_traits::FromPrimitive;
31use std::collections::HashMap;
32
33#[derive(Debug)]
34pub struct SafeBuffer<'a> {
35    buf: &'a [u8],
36    schema: &'a Schema<'a>,
37    buf_loc_to_obj_idx: HashMap<usize, i32>,
38}
39
40impl<'a> SafeBuffer<'a> {
41    pub fn new(buf: &'a [u8], schema: &'a Schema) -> FlatbufferResult<Self> {
42        Self::new_with_options(buf, schema, &VerifierOptions::default())
43    }
44
45    pub fn new_with_options(
46        buf: &'a [u8],
47        schema: &'a Schema,
48        opts: &VerifierOptions,
49    ) -> FlatbufferResult<Self> {
50        let mut buf_loc_to_obj_idx = HashMap::new();
51        verify_with_options(&buf, schema, opts, &mut buf_loc_to_obj_idx)?;
52        Ok(SafeBuffer {
53            buf,
54            schema,
55            buf_loc_to_obj_idx,
56        })
57    }
58
59    /// Gets the root table in the buffer.
60    pub fn get_root(&self) -> SafeTable {
61        // SAFETY: the buffer was verified during construction.
62        let table = unsafe { get_any_root(self.buf) };
63
64        SafeTable {
65            safe_buf: self,
66            loc: table.loc(),
67        }
68    }
69
70    fn find_field_by_name(
71        &self,
72        buf_loc: usize,
73        field_name: &str,
74    ) -> FlatbufferResult<Option<Field>> {
75        Ok(self
76            .get_all_fields(buf_loc)?
77            .lookup_by_key(field_name, |field: &Field<'_>, key| {
78                field.key_compare_with_value(key)
79            }))
80    }
81
82    fn get_all_fields(&self, buf_loc: usize) -> FlatbufferResult<Vector<ForwardsUOffset<Field>>> {
83        if let Some(&obj_idx) = self.buf_loc_to_obj_idx.get(&buf_loc) {
84            let obj = if obj_idx == -1 {
85                self.schema.root_table().unwrap()
86            } else {
87                self.schema.objects().get(obj_idx.try_into()?)
88            };
89            Ok(obj.fields())
90        } else {
91            Err(FlatbufferError::InvalidTableOrStruct)
92        }
93    }
94}
95
96#[derive(Debug)]
97pub struct SafeTable<'a> {
98    safe_buf: &'a SafeBuffer<'a>,
99    loc: usize,
100}
101
102impl<'a> SafeTable<'a> {
103    /// Gets an integer table field given its exact type. Returns default integer value if the field is not set. Returns [None] if no default value is found. Returns error if
104    /// the table doesn't match the buffer or
105    /// the [field_name] doesn't match the table or
106    /// the field type doesn't match.
107    pub fn get_field_integer<T: for<'b> Follow<'b, Inner = T> + PrimInt + FromPrimitive>(
108        &self,
109        field_name: &str,
110    ) -> FlatbufferResult<Option<T>> {
111        if let Some(field) = self.safe_buf.find_field_by_name(self.loc, field_name)? {
112            // SAFETY: the buffer was verified during construction.
113            unsafe { get_field_integer::<T>(&Table::new(&self.safe_buf.buf, self.loc), &field) }
114        } else {
115            Err(FlatbufferError::FieldNotFound)
116        }
117    }
118
119    /// Gets a floating point table field given its exact type. Returns default float value if the field is not set. Returns [None] if no default value is found. Returns error if
120    /// the table doesn't match the buffer or
121    /// the [field_name] doesn't match the table or
122    /// the field type doesn't match.
123    pub fn get_field_float<T: for<'b> Follow<'b, Inner = T> + Float>(
124        &self,
125        field_name: &str,
126    ) -> FlatbufferResult<Option<T>> {
127        if let Some(field) = self.safe_buf.find_field_by_name(self.loc, field_name)? {
128            // SAFETY: the buffer was verified during construction.
129            unsafe { get_field_float::<T>(&Table::new(&self.safe_buf.buf, self.loc), &field) }
130        } else {
131            Err(FlatbufferError::FieldNotFound)
132        }
133    }
134
135    /// Gets a String table field given its exact type. Returns empty string if the field is not set. Returns [None] if no default value is found. Returns error if
136    /// the table doesn't match the buffer or
137    /// the [field_name] doesn't match the table or
138    /// the field type doesn't match.
139    pub fn get_field_string(&self, field_name: &str) -> FlatbufferResult<Option<&str>> {
140        if let Some(field) = self.safe_buf.find_field_by_name(self.loc, field_name)? {
141            // SAFETY: the buffer was verified during construction.
142            unsafe { get_field_string(&Table::new(&self.safe_buf.buf, self.loc), &field) }
143        } else {
144            Err(FlatbufferError::FieldNotFound)
145        }
146    }
147
148    /// Gets a [SafeStruct] table field given its exact type. Returns [None] if the field is not set. Returns error if
149    /// the table doesn't match the buffer or
150    /// the [field_name] doesn't match the table or
151    /// the field type doesn't match.
152    pub fn get_field_struct(&self, field_name: &str) -> FlatbufferResult<Option<SafeStruct<'a>>> {
153        if let Some(field) = self.safe_buf.find_field_by_name(self.loc, field_name)? {
154            // SAFETY: the buffer was verified during construction.
155            let optional_st =
156                unsafe { get_field_struct(&Table::new(&self.safe_buf.buf, self.loc), &field)? };
157            Ok(optional_st.map(|st| SafeStruct {
158                safe_buf: self.safe_buf,
159                loc: st.loc(),
160            }))
161        } else {
162            Err(FlatbufferError::FieldNotFound)
163        }
164    }
165
166    /// Gets a Vector table field given its exact type. Returns empty vector if the field is not set. Returns error if
167    /// the table doesn't match the buffer or
168    /// the [field_name] doesn't match the table or
169    /// the field type doesn't match.
170    pub fn get_field_vector<T: Follow<'a, Inner = T>>(
171        &self,
172        field_name: &str,
173    ) -> FlatbufferResult<Option<Vector<'a, T>>> {
174        if let Some(field) = self.safe_buf.find_field_by_name(self.loc, field_name)? {
175            // SAFETY: the buffer was verified during construction.
176            unsafe { get_field_vector(&Table::new(&self.safe_buf.buf, self.loc), &field) }
177        } else {
178            Err(FlatbufferError::FieldNotFound)
179        }
180    }
181
182    /// Gets a [SafeTable] table field given its exact type. Returns [None] if the field is not set. Returns error if
183    /// the table doesn't match the buffer or
184    /// the [field_name] doesn't match the table or
185    /// the field type doesn't match.
186    pub fn get_field_table(&self, field_name: &str) -> FlatbufferResult<Option<SafeTable<'a>>> {
187        if let Some(field) = self.safe_buf.find_field_by_name(self.loc, field_name)? {
188            // SAFETY: the buffer was verified during construction.
189            let optional_table =
190                unsafe { get_field_table(&Table::new(&self.safe_buf.buf, self.loc), &field)? };
191            Ok(optional_table.map(|t| SafeTable {
192                safe_buf: self.safe_buf,
193                loc: t.loc(),
194            }))
195        } else {
196            Err(FlatbufferError::FieldNotFound)
197        }
198    }
199
200    /// Returns the value of any table field as a 64-bit int, regardless of what type it is. Returns default integer if the field is not set or error if
201    /// the value cannot be parsed as integer or
202    /// the table doesn't match the buffer or
203    /// the [field_name] doesn't match the table.
204    /// [num_traits](https://docs.rs/num-traits/latest/num_traits/cast/trait.NumCast.html) is used for number casting.
205    pub fn get_any_field_integer(&self, field_name: &str) -> FlatbufferResult<i64> {
206        if let Some(field) = self.safe_buf.find_field_by_name(self.loc, field_name)? {
207            // SAFETY: the buffer was verified during construction.
208            unsafe { get_any_field_integer(&Table::new(&self.safe_buf.buf, self.loc), &field) }
209        } else {
210            Err(FlatbufferError::FieldNotFound)
211        }
212    }
213
214    /// Returns the value of any table field as a 64-bit floating point, regardless of what type it is. Returns default float if the field is not set or error if
215    /// the value cannot be parsed as float or
216    /// the table doesn't match the buffer or
217    /// the [field_name] doesn't match the table.
218    pub fn get_any_field_float(&self, field_name: &str) -> FlatbufferResult<f64> {
219        if let Some(field) = self.safe_buf.find_field_by_name(self.loc, field_name)? {
220            // SAFETY: the buffer was verified during construction.
221            unsafe { get_any_field_float(&Table::new(&self.safe_buf.buf, self.loc), &field) }
222        } else {
223            Err(FlatbufferError::FieldNotFound)
224        }
225    }
226
227    /// Returns the string representation of any table field value (e.g. integer 123 is returned as "123"), regardless of what type it is. Returns empty string if the field is not set. Returns error if
228    /// the table doesn't match the buffer or
229    /// the [field_name] doesn't match the table.
230    pub fn get_any_field_string(&self, field_name: &str) -> FlatbufferResult<String> {
231        if let Some(field) = self.safe_buf.find_field_by_name(self.loc, field_name)? {
232            // SAFETY: the buffer was verified during construction.
233            unsafe {
234                Ok(get_any_field_string(
235                    &Table::new(&self.safe_buf.buf, self.loc),
236                    &field,
237                    self.safe_buf.schema,
238                ))
239            }
240        } else {
241            Err(FlatbufferError::FieldNotFound)
242        }
243    }
244}
245
246#[derive(Debug)]
247pub struct SafeStruct<'a> {
248    safe_buf: &'a SafeBuffer<'a>,
249    loc: usize,
250}
251
252impl<'a> SafeStruct<'a> {
253    /// Gets a [SafeStruct] struct field given its exact type. Returns error if
254    /// the struct doesn't match the buffer or
255    /// the [field_name] doesn't match the struct or
256    /// the field type doesn't match.
257    pub fn get_field_struct(&self, field_name: &str) -> FlatbufferResult<SafeStruct<'a>> {
258        if let Some(field) = self.safe_buf.find_field_by_name(self.loc, field_name)? {
259            // SAFETY: the buffer was verified during construction.
260            let st = unsafe {
261                get_field_struct_in_struct(&Struct::new(&self.safe_buf.buf, self.loc), &field)?
262            };
263            Ok(SafeStruct {
264                safe_buf: self.safe_buf,
265                loc: st.loc(),
266            })
267        } else {
268            Err(FlatbufferError::FieldNotFound)
269        }
270    }
271
272    /// Returns the value of any struct field as a 64-bit int, regardless of what type it is. Returns error if
273    /// the struct doesn't match the buffer or
274    /// the [field_name] doesn't match the struct or
275    /// the value cannot be parsed as integer.
276    pub fn get_any_field_integer(&self, field_name: &str) -> FlatbufferResult<i64> {
277        if let Some(field) = self.safe_buf.find_field_by_name(self.loc, field_name)? {
278            // SAFETY: the buffer was verified during construction.
279            unsafe {
280                get_any_field_integer_in_struct(&Struct::new(&self.safe_buf.buf, self.loc), &field)
281            }
282        } else {
283            Err(FlatbufferError::FieldNotFound)
284        }
285    }
286
287    /// Returns the value of any struct field as a 64-bit floating point, regardless of what type it is. Returns error if
288    /// the struct doesn't match the buffer or
289    /// the [field_name] doesn't match the struct or
290    /// the value cannot be parsed as float.
291    pub fn get_any_field_float(&self, field_name: &str) -> FlatbufferResult<f64> {
292        if let Some(field) = self.safe_buf.find_field_by_name(self.loc, field_name)? {
293            // SAFETY: the buffer was verified during construction.
294            unsafe {
295                get_any_field_float_in_struct(&Struct::new(&self.safe_buf.buf, self.loc), &field)
296            }
297        } else {
298            Err(FlatbufferError::FieldNotFound)
299        }
300    }
301
302    /// Returns the string representation of any struct field value (e.g. integer 123 is returned as "123"), regardless of what type it is. Returns error if
303    /// the struct doesn't match the buffer or
304    /// the [field_name] doesn't match the struct.
305    pub fn get_any_field_string(&self, field_name: &str) -> FlatbufferResult<String> {
306        if let Some(field) = self.safe_buf.find_field_by_name(self.loc, field_name)? {
307            // SAFETY: the buffer was verified during construction.
308            unsafe {
309                Ok(get_any_field_string_in_struct(
310                    &Struct::new(&self.safe_buf.buf, self.loc),
311                    &field,
312                    self.safe_buf.schema,
313                ))
314            }
315        } else {
316            Err(FlatbufferError::FieldNotFound)
317        }
318    }
319}