use alloc::alloc::{Allocator, Layout};
use alloc::collections::vec_deque::VecDeque;
use alloc::vec;
use alloc::vec::Vec;
use alloc::{alloc::AllocError, collections::TryReserveError};
use core::cmp;
use core::ops::Index;
use core::ptr::NonNull;

use custom_error::custom_error;

use crate::{IOAddr, PAddr, VAddr};

// Custom error type for IO memory operations.
custom_error! {pub IOMemError
    OutOfMemory = "reached out of memory",
    NotYetImplemented = "feature not yet implemented"
}

impl From<TryReserveError> for IOMemError {
    fn from(_e: TryReserveError) -> Self {
        IOMemError::OutOfMemory
    }
}

/// TODO: get rid of this hard-coded kernel base address.
pub const KERNEL_BASE: u64 = 0x400000000000;

/// A trait to tag objects that a device needs to read or write over DMA.
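///
/// A minimal sketch of opting a type into the default address conversions
/// (the `RxDescriptor` type here is hypothetical and only for illustration):
///
/// ```ignore
/// #[repr(C)]
/// struct RxDescriptor {
///     addr: u64,
///     len: u32,
///     flags: u32,
/// }
///
/// // All methods have default implementations derived from the object's address.
/// impl DmaObject for RxDescriptor {}
/// ```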
pub trait DmaObject {
    fn paddr(&self) -> PAddr {
        PAddr::from(self as *const Self as *const () as u64) - PAddr::from(KERNEL_BASE)
    }

    fn vaddr(&self) -> VAddr {
        VAddr::from(self as *const Self as *const () as usize)
    }

    fn ioaddr(&self) -> IOAddr {
        IOAddr::from(self.paddr().as_u64())
    }
}

/// An allocator that (supposedly) backs memory accessible by devices.
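///
/// Since it implements the (unstable) `Allocator` trait, it can back standard
/// collections directly. A minimal sketch (the size is arbitrary; memory is
/// zeroed by `allocate`):
///
/// ```ignore
/// let dma = DmaAllocator::default();
/// let mut ring: Vec<u8, DmaAllocator> = Vec::with_capacity_in(4096, dma);
/// ring.resize(4096, 0);
/// assert_eq!(ring.len(), 4096);
/// ```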
#[derive(Debug, Default, Clone, Copy)]
pub struct DmaAllocator;

unsafe impl Allocator for DmaAllocator {
    /// Allocates IO memory.
    fn allocate(&self, layout: Layout) -> Result<NonNull<[u8]>, AllocError> {
        // TODO: ensure IOMMU stuff etc. here, for now:
        unsafe {
            // Defer the actual allocation to the OS/global allocator.
            let ptr: *mut u8 = alloc::alloc::alloc_zeroed(layout);
            // Wrap the raw pointer in a NonNull slice for the return value,
            // or report an allocation failure if the pointer is null.
            match NonNull::new(ptr) {
                Some(ptr) => Ok(NonNull::slice_from_raw_parts(ptr, layout.size())),
                None => Err(AllocError),
            }
        }
    }

    /// Deallocates the previously allocated IO memory.
    unsafe fn deallocate(&self, ptr: NonNull<u8>, layout: Layout) {
        // TODO: ensure IOMMU stuff, for now:
        let buf = ptr.as_ptr();
        alloc::alloc::dealloc(buf, layout);
    }
}

#[derive(Debug)]
/// Represents an IO buffer (data handed to or from a device).
pub struct IOBuf {
    buf: Vec<u8, DmaAllocator>,
}

impl IOBuf {
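    /// Creates a new `IOBuf` for the given `layout`, backed by the
    /// `DmaAllocator` and zero-filled up to `layout.size()` bytes.
    ///
    /// A minimal sketch (the 4 KiB size and alignment are arbitrary):
    ///
    /// ```ignore
    /// let layout = Layout::from_size_align(4096, 4096).unwrap();
    /// let buf = IOBuf::new(layout)?;
    /// assert_eq!(buf.len(), 4096);
    /// ```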
    pub fn new(layout: Layout) -> Result<IOBuf, IOMemError> {
        // Allocate the backing storage from the DMA allocator.
        // Note: only `layout.size()` is used here; the requested alignment is
        // not currently propagated to the `Vec` allocation.
        let allocator = DmaAllocator::default();
        let buf: Vec<u8, DmaAllocator> = Vec::with_capacity_in(layout.size(), allocator);
        let mut iobuf = IOBuf { buf };
        // Expand so the buffer starts out zero-filled at its full capacity.
        iobuf.expand();
        Ok(iobuf)
    }

    /// Fills the buffer with zeros up to its full capacity.
    pub fn expand(&mut self) {
        self.buf.resize(self.buf.capacity(), 0);
    }

    /// Truncates the buffer, keeping only the first `new_len` bytes.
    pub fn truncate(&mut self, new_len: usize) {
        self.buf.truncate(new_len)
    }

    /// Removes all buffer contents.
    pub fn clear(&mut self) {
        self.buf.clear();
    }

    /// Copy data from `src` into a given `offset` of the `IOBuf`.
    pub fn copy_in_at(&mut self, offset: usize, src: &[u8]) -> Result<usize, IOMemError> {
        // Copying must not grow the buffer beyond its existing capacity:
        let remaining_capacity = self.buf.capacity() - offset;
        let cnt = cmp::min(remaining_capacity, src.len());
        // Resize so the valid region covers exactly `offset + cnt` bytes.
        self.buf.resize(offset + cnt, 0);

        // copy the slice
        self.buf[offset..offset + cnt].copy_from_slice(&src[0..cnt]);

        Ok(cnt)
    }

    /// Copy the contents of `src` into the buffer, starting at offset 0.
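    ///
    /// Returns the number of bytes copied. A round-trip sketch (buffer size
    /// and alignment are arbitrary):
    ///
    /// ```ignore
    /// let mut buf = IOBuf::new(Layout::from_size_align(64, 8).unwrap())?;
    /// buf.clear();
    /// let written = buf.copy_in(&[0xde, 0xad, 0xbe, 0xef])?;
    /// let mut out = [0u8; 4];
    /// let read = buf.copy_out(&mut out)?;
    /// assert_eq!((written, read), (4, 4));
    /// assert_eq!(out, [0xde, 0xad, 0xbe, 0xef]);
    /// ```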
    pub fn copy_in(&mut self, src: &[u8]) -> Result<usize, IOMemError> {
        self.copy_in_at(0, src)
    }

    /// Copy data out of the IOBuf, starting at a given `offset` into `dst`.
    pub fn copy_out_at(&self, offset: usize, dst: &mut [u8]) -> Result<usize, IOMemError> {
        // If the offset is beyond the buffer's current length there is nothing to copy.
        if offset >= self.buf.len() {
            return Ok(0);
        }

        let cnt = cmp::min(self.buf.len() - offset, dst.len());
        // copy the slice
        dst[0..cnt].copy_from_slice(&self.buf[offset..offset + cnt]);
        Ok(cnt)
    }

    /// Copy the buffer contents (starting at offset 0) into the `dst` slice.
    pub fn copy_out(&self, dst: &mut [u8]) -> Result<usize, IOMemError> {
        self.copy_out_at(0, dst)
    }

    /// Get the IOBuf contents as a slice.
    pub fn as_slice(&self) -> &[u8] {
        self.buf.as_slice()
    }

    /// Get the IOBuf contents as a mutable slice.
    pub fn as_mut_slice(&mut self) -> &mut [u8] {
        self.buf.as_mut_slice()
    }

    /// Returns the current length (number of valid bytes) of the buffer.
    pub fn len(&self) -> usize {
        self.buf.len()
    }

    /// Returns `true` if the buffer currently holds no data.
    pub fn is_empty(&self) -> bool {
        self.buf.is_empty()
    }
}

/// Implementation of the index operator (`[]`) for `IOBuf`.
impl Index<usize> for IOBuf {
    /// The returned type after indexing.
    type Output = u8;

    /// Performs the indexing (`container[index]`) operation.
    #[inline]
    fn index(&self, index: usize) -> &Self::Output {
        &self.buf[index]
    }
}

impl DmaObject for IOBuf {
    /// Physical address of the IOBuf in main memory.
    fn paddr(&self) -> PAddr {
        PAddr::from(self.buf.as_ptr() as u64 - KERNEL_BASE)
    }

    /// Virtual address at which software can access this buffer.
    fn vaddr(&self) -> VAddr {
        VAddr::from(self.buf.as_ptr() as u64)
    }
}

/// A pool of `IOBuf` buffers of the same size, intended for a single device.
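///
/// Buffers handed back via [`IOBufPool::put_buf`] are recycled by later
/// [`IOBufPool::get_buf`] calls instead of being freed. A usage sketch
/// (2 KiB buffers with 128-byte alignment, chosen arbitrarily):
///
/// ```ignore
/// let mut pool = IOBufPool::new(2048, 128)?;
/// let buf = pool.get_buf()?;
/// // ... hand the buffer to the device and use it ...
/// pool.put_buf(buf);
/// ```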
pub struct IOBufPool {
    /// Pool of buffers
    pool: Vec<IOBuf>,
    /// The allocator used for new buffers
    _allocator: DmaAllocator,
    /// The allocation layout of the buffers
    layout: Layout,
}

impl IOBufPool {
    /// Creates an empty pool for buffers of `len` bytes aligned to `align` bytes.
    pub fn new(len: usize, align: usize) -> Result<IOBufPool, IOMemError> {
        let layout = Layout::from_size_align(len, align).expect("Layout was invalid.");
        let allocator = DmaAllocator::default();

        Ok(IOBufPool {
            pool: vec![],
            _allocator: allocator,
            layout,
        })
    }

    /// Returns a buffer from the pool, allocating a new one if the pool is empty.
    pub fn get_buf(&mut self) -> Result<IOBuf, IOMemError> {
        if !self.pool.is_empty() {
            let mut buf = self.pool.pop().expect("should have a buffer here");
            // Reset the recycled buffer: grow it back to full capacity, then
            // clear it so it starts out empty for the next user.
            buf.expand();
            buf.clear();
            Ok(buf)
        } else {
            IOBuf::new(self.layout)
        }
    }

    /// Returns a buffer to the pool so it can be reused later.
    pub fn put_buf(&mut self, buf: IOBuf) {
        self.pool.push(buf)
    }
}

#[derive(Debug)]
/// A chain of `IOBuf` segments that together form one logical IO buffer
/// (e.g., a packet split across several fragments).
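///
/// A sketch of assembling a two-segment chain (segment sizes are arbitrary):
///
/// ```ignore
/// let seg_layout = Layout::from_size_align(2048, 128).unwrap();
/// let mut chain = IOBufChain::new(0, 2)?;
/// chain.append(IOBuf::new(seg_layout)?);
/// chain.append(IOBuf::new(seg_layout)?);
/// assert_eq!(chain.segments.len(), 2);
/// ```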
pub struct IOBufChain {
    /// Completion queue index (set by driver),
    /// TODO: remove once no longer necessary?
    cqidx: usize,

    /// Check sum flags (set by driver on rx)
    pub csum_flags: u32,

    /// Checksum data (set by driver on rx)
    pub csum_data: u32,

    /// VLAN tag (set by device driver on rx)
    pub vtag: Option<u32>,

    /// Flags (to be used by device driver).
    pub flags: u32,

    /// Flow ID for RSS
    pub rss_flow_id: Option<usize>,

    /// RSS type
    pub rss_type: u32,

    /// The `IOBuf` fragments
    pub segments: VecDeque<IOBuf>,
}

impl IOBufChain {
    /// Creates a new, empty chain with the given `flags` and room reserved for `len` segments.
    pub fn new(flags: u32, len: usize) -> Result<IOBufChain, IOMemError> {
        let mut vd = VecDeque::new();
        vd.try_reserve_exact(len)?;

        Ok(IOBufChain {
            cqidx: 0,
            flags,
            csum_flags: 0,
            csum_data: 0,
            vtag: None,
            rss_flow_id: None,
            rss_type: 0,
            segments: vd,
        })
    }

    /// Set the metadata provided by the driver.
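    ///
    /// Trailing segments that lie entirely beyond `total_len` are truncated to
    /// zero length. A sketch, assuming a chain of two 2 KiB segments of which
    /// the first carries a 1500-byte frame (already truncated by the driver):
    ///
    /// ```ignore
    /// chain.segments[0].truncate(1500);
    /// chain.set_meta_data(1500, 1, 0, None, 0);
    /// assert_eq!(chain.segments[1].len(), 0); // trailing segment was unused
    /// ```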
    pub fn set_meta_data(
        &mut self,
        total_len: usize,
        segments: usize,
        cqidx: usize,
        rss_flow_id: Option<usize>,
        rsstype: u32,
    ) {
        self.cqidx = cqidx;
        self.rss_flow_id = rss_flow_id;
        self.rss_type = rsstype;

        // Walk the segments: once `total_len` bytes have been accounted for,
        // any remaining segments are unused and are truncated to zero length.
        let mut remaining_bytes = total_len;
        let mut unused_segments = 0;
        for seg in self.segments.iter_mut() {
            if remaining_bytes == 0 {
                seg.truncate(0); // unused segment
                unused_segments += 1;
            }
            remaining_bytes -= seg.len();
        }

        assert_eq!(
            segments,
            self.segments.len() - unused_segments,
            "#Segments match"
        );
        assert_eq!(
            total_len,
            self.segments.iter().map(|s| s.len()).sum(),
            "total_len matches"
        );
    }

    /// Appends an `IOBuf` segment to the end of the chain.
    pub fn append(&mut self, buf: IOBuf) {
        self.segments.push_back(buf);
    }
}

/// Implementation of the index operator (`[]`) for `IOBufChain`.
impl Index<usize> for IOBufChain {
    /// The returned type after indexing.
    type Output = u8;

    /// Performs the indexing (`container[index]`) operation, treating the
    /// chain as one contiguous run of bytes across all segments.
    fn index(&self, index: usize) -> &Self::Output {
        // Walk the segments, translating the global index into a
        // segment-local index.
        let mut cidx = index;
        for seg in self.segments.iter() {
            let seglen = seg.len();
            if cidx < seglen {
                return &seg[cidx];
            }
            cidx -= seglen;
        }
        panic!("index {} out of bounds for IOBufChain", index);
    }
}