mssf_core/
iter.rs

1// ------------------------------------------------------------
2// Copyright (c) Microsoft Corporation.  All rights reserved.
3// Licensed under the MIT License (MIT). See License.txt in the repo root for license information.
4// ------------------------------------------------------------
5
// Iterator implementation.
// Iteration infrastructure for converting raw Fabric lists into safe Rust wrappers.
// A raw list must be wrapped in a `FabricListAccessor`, and the safe item type must
// implement `From<&T>` for the raw item type `T`; `FabricIter` then converts the
// items one by one while iterating.
// This is currently not used by the core APIs, which convert directly into `Vec<T>`
// for simplicity. It is useful for callers with high performance requirements.
13
14use std::marker::PhantomData;
15
/// Read access to the metadata of a raw Fabric list.
///
/// `T` is the raw Fabric item type. The list is walked by pointer arithmetic as
/// `get_count()` consecutive `T` values starting at `get_first_item()`.
///
/// NOTE(review): implementations must return a pointer that is valid for reads of
/// `get_count()` items for as long as the owner passed to [`FabricIter::new`] is
/// borrowed; [`FabricIter`] dereferences it without further checks. The type
/// system does not enforce this (the trait is not `unsafe`).
pub trait FabricListAccessor<T> {
    /// Number of items in the list.
    fn get_count(&self) -> u32;
    /// Pointer to the first item. It is never dereferenced when `get_count()` is 0,
    /// so it may be null in that case.
    fn get_first_item(&self) -> *const T;
}

/// Iterator that converts the items of a raw Fabric list into safe Rust values,
/// one item per call to `next`.
///
/// - `T` is the raw Fabric item type.
/// - `R` is the safe type each item is converted into, via `From<&T>`.
/// - `O` is the owner of the memory the raw items live in (typically a COM
///   object); it is only borrowed, to tie the iterator's lifetime to it.
pub struct FabricIter<'b, T, R, O>
where
    R: for<'a> std::convert::From<&'a T>,
{
    /// Keeps the memory that `curr` points into alive. Typically a COM object.
    _owner: &'b O,
    /// Total number of items in the list.
    count: u32,
    /// Number of items already yielded.
    index: u32,
    /// Pointer to the next item to yield.
    curr: *const T,
    /// Marks `R`, the converted item type.
    phantom: PhantomData<R>,
}

impl<'b, T, R, O> FabricIter<'b, T, R, O>
where
    R: for<'a> std::convert::From<&'a T>,
{
    /// Creates an iterator over the list described by `accessor`, whose memory is
    /// owned by `owner`.
    ///
    /// The count and first-item pointer are read once, here; later changes seen
    /// through the accessor are not observed by the iterator.
    pub fn new(accessor: &impl FabricListAccessor<T>, owner: &'b O) -> Self {
        Self {
            _owner: owner,
            count: accessor.get_count(),
            index: 0,
            curr: accessor.get_first_item(),
            phantom: PhantomData,
        }
    }

    /// Number of items not yet yielded.
    fn remaining(&self) -> usize {
        // `next` only increments `index` while it is below `count`, so this never
        // underflows.
        (self.count - self.index) as usize
    }
}

impl<T, R, O> Iterator for FabricIter<'_, T, R, O>
where
    R: for<'a> std::convert::From<&'a T>,
{
    type Item = R;

    fn next(&mut self) -> Option<Self::Item> {
        if self.index >= self.count {
            return None;
        }
        assert!(
            !self.curr.is_null(),
            "FabricIter: null item pointer with {} item(s) remaining",
            self.remaining()
        );
        let item = if self.curr.is_aligned() {
            // SAFETY: `curr` is non-null and aligned, and per the accessor contract
            // it points to a valid `T` kept alive by `_owner`.
            R::from(unsafe { &*self.curr })
        } else {
            // Fabric COM pointers are sometimes not aligned for `T`, and forming a
            // `&T` from an unaligned pointer is undefined behavior. Copy the item
            // into an aligned local instead; `ManuallyDrop` keeps the bitwise copy
            // from being dropped, because the original memory still owns it.
            // SAFETY: `curr` is non-null and, per the accessor contract, valid for
            // reads of one `T`; `read_unaligned` has no alignment requirement.
            let copy = std::mem::ManuallyDrop::new(unsafe { self.curr.read_unaligned() });
            R::from(&*copy)
        };
        self.index += 1;
        // SAFETY: at most one-past-the-end of the `count` items is computed.
        self.curr = unsafe { self.curr.add(1) };
        Some(item)
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        let n = self.remaining();
        (n, Some(n))
    }
}

// `size_hint` is exact: it is derived from the fixed `count`.
impl<T, R, O> ExactSizeIterator for FabricIter<'_, T, R, O> where
    R: for<'a> std::convert::From<&'a T>
{
}

// Once `index` reaches `count` it never changes again, so `next` keeps returning `None`.
impl<T, R, O> std::iter::FusedIterator for FabricIter<'_, T, R, O> where
    R: for<'a> std::convert::From<&'a T>
{
}
73
/// Convert a raw pointer and length into a `Vec` of a safe type.
///
/// Each of the `len` consecutive `T` items starting at `raw` is converted with
/// `V::from(&T)`, in order. Returns an empty `Vec` when `len` is 0 or `raw` is null.
///
/// NOTE(review): the caller must guarantee that a non-null `raw` is valid for reads
/// of `len` consecutive `T` values; this is not checked here.
pub(crate) fn vec_from_raw_com<T, V>(len: usize, raw: *const T) -> Vec<V>
where
    V: for<'a> std::convert::From<&'a T>,
{
    if len == 0 || raw.is_null() {
        return Vec::new();
    }
    if raw.is_aligned() {
        // SAFETY: `raw` is non-null and aligned, and the caller guarantees it points
        // to `len` valid `T` values.
        let items = unsafe { std::slice::from_raw_parts(raw, len) };
        items.iter().map(V::from).collect()
    } else {
        // Sometimes the SF COM pointer is not aligned for `T`. Neither a slice nor a
        // `&T` may be formed from an unaligned pointer (that is undefined behavior),
        // so copy each item into an aligned local with `read_unaligned` and convert
        // from that. `ManuallyDrop` stops the bitwise copy from being dropped; the
        // original memory still owns it.
        (0..len)
            .map(|i| {
                // SAFETY: `i < len`, so `raw.add(i)` stays within the caller-provided
                // items; `read_unaligned` has no alignment requirement.
                let copy = std::mem::ManuallyDrop::new(unsafe { raw.add(i).read_unaligned() });
                V::from(&*copy)
            })
            .collect()
    }
}
101
#[cfg(test)]
mod test {
    use super::{FabricIter, FabricListAccessor};

    /// Raw item stored in the test list.
    struct MyVal {
        val: String,
    }

    /// Safe item produced from a `MyVal` by appending a suffix.
    struct MyVal2 {
        val: String,
    }

    impl From<&MyVal> for MyVal2 {
        fn from(raw: &MyVal) -> Self {
            let mut val = raw.val.clone();
            val.push_str("Suffix");
            MyVal2 { val }
        }
    }

    /// Test list backed by a `Vec`; it acts as both the accessor and the owner.
    struct MyVec {
        v: Vec<MyVal>,
    }

    impl FabricListAccessor<MyVal> for MyVec {
        fn get_first_item(&self) -> *const MyVal {
            self.v.as_ptr()
        }

        fn get_count(&self) -> u32 {
            self.v.len() as u32
        }
    }

    type MyVecIter<'a> = FabricIter<'a, MyVal, MyVal2, MyVec>;

    impl MyVec {
        fn get_iter(&self) -> MyVecIter<'_> {
            // The list owns its own memory, so it is passed as both arguments.
            FabricIter::new(self, self)
        }
    }

    #[test]
    fn test_vector() {
        let list = MyVec {
            v: ["hi", "hi2"]
                .iter()
                .map(|&s| MyVal { val: s.to_owned() })
                .collect(),
        };

        let converted: Vec<MyVal2> = list.get_iter().collect();
        assert_eq!(converted.len(), 2);
        assert_eq!(converted[0].val, "hiSuffix");
        assert_eq!(converted[1].val, "hi2Suffix");
    }
}