Skip to main content

luwen_api/chip/
hl_comms.rs

1// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
2// SPDX-License-Identifier: Apache-2.0
3
4use crate::error::PlatformError;
5
6use super::{AxiData, AxiError, ChipComms, ChipInterface};
7
8/// Convenience trait for high-level communication with an arbitrary chip.
9pub trait HlComms {
10    fn comms_obj(&self) -> (&dyn ChipComms, &dyn ChipInterface);
11
12    fn noc_read(
13        &self,
14        noc_id: u8,
15        x: u8,
16        y: u8,
17        addr: u64,
18        data: &mut [u8],
19    ) -> Result<(), Box<dyn std::error::Error>> {
20        let (arc_if, chip_if) = self.comms_obj();
21        arc_if.noc_read(chip_if, noc_id, x, y, addr, data)
22    }
23
24    fn noc_write(
25        &self,
26        noc_id: u8,
27        x: u8,
28        y: u8,
29        addr: u64,
30        data: &[u8],
31    ) -> Result<(), Box<dyn std::error::Error>> {
32        let (arc_if, chip_if) = self.comms_obj();
33        arc_if.noc_write(chip_if, noc_id, x, y, addr, data)
34    }
35
36    fn noc_multicast(
37        &self,
38        noc_id: u8,
39        start: (u8, u8),
40        end: (u8, u8),
41        addr: u64,
42        data: &[u8],
43    ) -> Result<(), Box<dyn std::error::Error>> {
44        let (_, chip_if) = self.comms_obj();
45        chip_if.noc_multicast(noc_id, start, end, addr, data)
46    }
47
48    fn noc_broadcast(
49        &self,
50        noc_id: u8,
51        addr: u64,
52        data: &[u8],
53    ) -> Result<(), Box<dyn std::error::Error>> {
54        let (arc_if, chip_if) = self.comms_obj();
55        arc_if.noc_broadcast(chip_if, noc_id, addr, data)
56    }
57
58    fn noc_read32(
59        &self,
60        noc_id: u8,
61        x: u8,
62        y: u8,
63        addr: u64,
64    ) -> Result<u32, Box<dyn std::error::Error>> {
65        let (arc_if, chip_if) = self.comms_obj();
66        arc_if.noc_read32(chip_if, noc_id, x, y, addr)
67    }
68
69    fn noc_write32(
70        &self,
71        noc_id: u8,
72        x: u8,
73        y: u8,
74        addr: u64,
75        value: u32,
76    ) -> Result<(), Box<dyn std::error::Error>> {
77        let (arc_if, chip_if) = self.comms_obj();
78        arc_if.noc_write32(chip_if, noc_id, x, y, addr, value)
79    }
80
81    fn noc_broadcast32(
82        &self,
83        noc_id: u8,
84        addr: u64,
85        value: u32,
86    ) -> Result<(), Box<dyn std::error::Error>> {
87        let (arc_if, chip_if) = self.comms_obj();
88        arc_if.noc_broadcast32(chip_if, noc_id, addr, value)
89    }
90
91    fn axi_read(&self, addr: u64, data: &mut [u8]) -> Result<(), Box<dyn std::error::Error>> {
92        let (arc_if, chip_if) = self.comms_obj();
93        arc_if.axi_read(chip_if, addr, data)
94    }
95
96    fn axi_write(&self, addr: u64, data: &[u8]) -> Result<(), Box<dyn std::error::Error>> {
97        let (arc_if, chip_if) = self.comms_obj();
98        arc_if.axi_write(chip_if, addr, data)
99    }
100
101    fn axi_read32(&self, addr: u64) -> Result<u32, Box<dyn std::error::Error>> {
102        let (arc_if, chip_if) = self.comms_obj();
103        arc_if.axi_read32(chip_if, addr)
104    }
105
106    fn axi_write32(&self, addr: u64, value: u32) -> Result<(), Box<dyn std::error::Error>> {
107        let (arc_if, chip_if) = self.comms_obj();
108        arc_if.axi_write32(chip_if, addr, value)
109    }
110}
111
/// Logically shifts the multi-byte integer held in `existing` right by `shift` bits, in place.
///
/// Byte 0 is treated as least significant. Vacated high bits are filled with zeros, and a
/// shift of the full buffer width or more clears the whole buffer.
#[inline]
fn right_shift(existing: &mut [u8], shift: u32) {
    let total = shift as usize;
    let len = existing.len();

    if total >= len * 8 {
        existing.fill(0);
        return;
    }

    let (whole_bytes, bits) = (total / 8, total % 8);

    if whole_bytes != 0 {
        // Move each byte `whole_bytes` positions toward index 0, then zero the vacated tail.
        existing.copy_within(whole_bytes.., 0);
        existing[len - whole_bytes..].fill(0);
    }

    if bits != 0 {
        // Walk from the most significant byte down: the low `bits` bits of each byte become
        // the top bits of the byte below it.
        let mut incoming = 0u8;
        for byte in existing.iter_mut().rev() {
            let outgoing = *byte << (8 - bits);
            *byte = (*byte >> bits) | incoming;
            incoming = outgoing;
        }
    }
}
139
/// Logically shifts the multi-byte integer held in `existing` left by `shift` bits, in place.
///
/// Byte 0 is treated as least significant. Bits shifted past the last byte are discarded,
/// vacated low bits are zero-filled, and a shift of the full buffer width or more clears the
/// whole buffer.
// Not called anywhere in this module; kept as the counterpart of `right_shift`.
#[allow(dead_code)]
#[inline]
fn left_shift(existing: &mut [u8], shift: u32) {
    let total = shift as usize;
    let len = existing.len();

    if total >= len * 8 {
        existing.fill(0);
        return;
    }

    let byte_shift = total / 8;
    let bit_shift = total % 8;

    if byte_shift > 0 {
        // Move each byte `byte_shift` positions toward the end, then zero the vacated front.
        existing.copy_within(..len - byte_shift, byte_shift);
        existing[..byte_shift].fill(0);
    }

    if bit_shift > 0 {
        // Carries flow from less significant (lower index) to more significant (higher index)
        // bytes, so walk forwards. The top `bit_shift` bits of each byte become the bottom bits
        // of the next byte.
        let mut carry = 0u8;
        for byte in existing.iter_mut() {
            let next_carry = *byte >> (8 - bit_shift);
            *byte = (*byte << bit_shift) | carry;
            carry = next_carry;
        }
    }
}
174
/// Clears every bit of `existing` at position `high_bit` or above, where bit 0 is the least
/// significant bit of `existing[0]`.
///
/// Returns the prefix of `existing` needed to hold the `high_bit` bits that remain. That is
/// `ceil(high_bit / 8)` bytes, capped at `existing.len()`.
#[inline]
fn mask_off(existing: &mut [u8], high_bit: u32) -> &mut [u8] {
    let len = existing.len();
    let top_byte = high_bit as usize / 8;
    let top_bit = high_bit % 8;
    // Number of bytes that can still contain kept bits.
    let keep = ((high_bit as usize + 7) / 8).min(len);

    if top_byte < len {
        // Keep only the low `top_bit` bits of the partially covered byte (all of it is cleared
        // when `top_bit == 0`)...
        existing[top_byte] &= (1u8 << top_bit) - 1;
        // ...and clear every byte above it, so no stale bits survive past `high_bit`.
        existing[top_byte + 1..].fill(0);
    }

    &mut existing[..keep]
}
187
/// Writes the low `upper - lower + 1` bits of `value` into bits `lower..=upper` of `existing`.
///
/// All other bits of `existing` are left untouched. Bit 0 is the least significant bit of
/// index 0 in both buffers, and bytes missing from `value` are treated as zero.
///
/// # Panics
/// Panics if `upper < lower` or if bit `upper` lies beyond the end of `existing`.
fn write_modify(existing: &mut [u8], value: &[u8], lower: u32, upper: u32) {
    assert!(upper >= lower);
    assert!(existing.len() * 8 > upper as usize);

    let width = upper - lower + 1;
    let base = (lower / 8) as usize;
    let offset = lower % 8;

    // Process the source eight bits at a time. Each chunk, once shifted by `offset`, straddles
    // at most two destination bytes, so a 16-bit window covers it.
    let mut written = 0u32;
    let mut chunk = 0usize;
    while written < width {
        let bits = (width - written).min(8);
        let chunk_mask: u16 = (1u16 << bits) - 1;
        let source = u16::from(value.get(chunk).copied().unwrap_or(0)) & chunk_mask;

        let [low_val, high_val] = (source << offset).to_le_bytes();
        let [low_mask, high_mask] = (chunk_mask << offset).to_le_bytes();

        let dst = base + chunk;
        existing[dst] = (existing[dst] & !low_mask) | low_val;
        if high_mask != 0 {
            // Only reached when the chunk spills past a byte boundary; that byte is still
            // at or below `upper / 8`, so it is in bounds.
            existing[dst + 1] = (existing[dst + 1] & !high_mask) | high_val;
        }

        written += bits;
        chunk += 1;
    }
}
222
223fn read_modify(existing: &mut [u8], lower: u32, upper: u32) -> &[u8] {
224    assert!(upper >= lower);
225    assert!(existing.len() * 8 > upper as usize);
226
227    right_shift(existing, lower);
228    &*mask_off(existing, upper - lower + 1)
229}
230
231/// These functions can' be stored as a fat pointer so they are split out here.
232/// There is a blanket implementation for all types that implement HlComms.
233pub trait HlCommsInterface: HlComms {
234    fn axi_translate(&self, addr: impl AsRef<str>) -> Result<AxiData, AxiError> {
235        let (arc_if, _) = self.comms_obj();
236
237        arc_if.axi_translate(addr.as_ref())
238    }
239
240    fn axi_read_field<'a>(
241        &self,
242        addr: &AxiData,
243        value: &'a mut [u8],
244    ) -> Result<&'a [u8], PlatformError> {
245        let (arc_if, chip_if) = self.comms_obj();
246
247        if value.len() < addr.size as usize {
248            return Err(AxiError::ReadBufferTooSmall)?;
249        }
250
251        arc_if.axi_read(chip_if, addr.addr, &mut value[..addr.size as usize])?;
252
253        let value = if let Some((lower, upper)) = addr.bits {
254            read_modify(value, lower, upper);
255
256            value
257        } else {
258            &mut value[..addr.size as usize]
259        };
260
261        Ok(&*value)
262    }
263
264    fn axi_write_field(&self, addr: &AxiData, value: &[u8]) -> Result<(), PlatformError> {
265        let (arc_if, chip_if) = self.comms_obj();
266
267        if value.len() < addr.size as usize {
268            return Err(AxiError::ReadBufferTooSmall)?;
269        }
270
271        if let Some((lower, upper)) = addr.bits {
272            let mut existing = vec![0u8; addr.size as usize];
273            arc_if.axi_read(chip_if, addr.addr, &mut existing)?;
274
275            write_modify(&mut existing, value, lower, upper);
276
277            arc_if.axi_write(chip_if, addr.addr, &existing)?;
278        } else {
279            // We are writing the full size of the field
280            arc_if.axi_write(chip_if, addr.addr, &value[..addr.size as usize])?;
281        };
282
283        Ok(())
284    }
285
286    fn axi_sread<'a>(
287        &self,
288        addr: impl AsRef<str>,
289        value: &'a mut [u8],
290    ) -> Result<&'a [u8], PlatformError> {
291        let (arc_if, _chip_if) = self.comms_obj();
292
293        let addr = addr.as_ref();
294        let addr = arc_if.axi_translate(addr)?;
295
296        self.axi_read_field(&addr, value)
297    }
298
299    fn axi_sread_to_vec(&self, addr: impl AsRef<str>) -> Result<Vec<u8>, PlatformError> {
300        let (arc_if, chip_if) = self.comms_obj();
301
302        let addr = addr.as_ref();
303
304        let addr = arc_if.axi_translate(addr)?;
305
306        let mut output = Vec::with_capacity(addr.size as usize);
307
308        let value: &mut [u8] = unsafe { std::mem::transmute(output.spare_capacity_mut()) };
309
310        arc_if.axi_read(chip_if, addr.addr, &mut value[..addr.size as usize])?;
311
312        unsafe {
313            output.set_len(addr.size as usize);
314        }
315
316        Ok(output)
317    }
318
319    fn axi_sread32(&self, addr: impl AsRef<str>) -> Result<u32, PlatformError> {
320        let mut value = [0; 4];
321
322        let value = self.axi_sread(addr, &mut value)?;
323
324        let mut output = 0;
325        for o in value.iter().rev() {
326            output <<= 8;
327            output |= *o as u32;
328        }
329
330        Ok(output)
331    }
332
333    fn axi_swrite(&self, addr: impl AsRef<str>, value: &[u8]) -> Result<(), PlatformError> {
334        let (arc_if, _chip_if) = self.comms_obj();
335
336        let addr = arc_if.axi_translate(addr.as_ref())?;
337
338        self.axi_write_field(&addr, value)
339    }
340
341    fn axi_swrite32(&self, addr: impl AsRef<str>, value: u32) -> Result<(), PlatformError> {
342        self.axi_swrite(addr, &value.to_le_bytes())
343    }
344}
345
346impl<T: HlComms> HlCommsInterface for T {}
347
#[cfg(test)]
mod test {
    use super::{read_modify, write_modify};

    #[test]
    fn test_read_modify() {
        let mut register = [0u8, 1, 2, 3];

        // Extracting the full 32-bit width returns the register unchanged.
        let field = read_modify(&mut register, 0, 31);
        assert_eq!(field, &[0, 1, 2, 3]);
    }

    #[test]
    fn test_read_modify_bit() {
        let mut register = [0u8, 1, 2, 3];

        // Bit 8 is the lowest bit of byte 1, which is set.
        let field = read_modify(&mut register, 8, 8);
        assert_eq!(field, &[1]);
    }

    #[test]
    fn test_read_modify_bits() {
        let mut register = [0u8, 1, 2, 3];

        let field = read_modify(&mut register, 19, 25);
        assert_eq!(field, &[96]);
    }

    #[test]
    fn test_read_modify_top() {
        let mut register = [0u8, 0, 0, 0x80];

        let field = read_modify(&mut register, 31, 31);
        assert_eq!(field, &[1]);
    }

    #[test]
    fn test_read_modify_bottom() {
        let mut register = [0x1u8, 0, 0, 0];

        let field = read_modify(&mut register, 0, 0);
        assert_eq!(field, &[1]);
    }

    #[test]
    fn test_write_modify() {
        let mut register = [0u8, 1, 2, 3];

        // Bytes missing from the value are written as zero.
        write_modify(&mut register, &[0b110], 0, 31);
        assert_eq!(register, [6, 0, 0, 0]);
    }

    #[test]
    fn test_write_modify_bit_low() {
        let mut register = [0u8, 1, 2, 3];
        let zero = [0b0u8];

        // Clearing a bit that is already clear is a no-op.
        write_modify(&mut register, &zero, 16, 16);
        assert_eq!(register, [0, 1, 2, 3]);

        // Clearing bit 24 (lowest bit of the top byte) turns 3 into 2.
        write_modify(&mut register, &zero, 24, 24);
        assert_eq!(register, [0, 1, 2, 2]);
    }

    #[test]
    fn test_write_modify_bit_high() {
        let mut register = [0u8, 1, 2, 3];
        let one = [0b1u8];

        // Setting a bit that is already set is a no-op.
        write_modify(&mut register, &one, 25, 25);
        assert_eq!(register, [0, 1, 2, 3]);

        // Setting bit 18 turns byte 2 from 2 into 6.
        write_modify(&mut register, &one, 18, 18);
        assert_eq!(register, [0, 1, 6, 3]);
    }

    #[test]
    fn test_write_modify_bits() {
        let mut register = [0u8, 1, 2, 3];

        write_modify(&mut register, &[0b110], 13, 19);
        assert_eq!(register, [0, 193, 0, 3]);
    }

    #[test]
    fn test_write_modify_top() {
        let mut register = [0u8; 4];

        write_modify(&mut register, &[0b1], 31, 31);
        assert_eq!(register, [0, 0, 0, 0x80]);
    }
}
451}