luwen_ref/
lib.rs

1// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
2// SPDX-License-Identifier: Apache-2.0
3
4use std::{
5    collections::HashMap,
6    sync::{Arc, RwLock, RwLockReadGuard, RwLockWriteGuard},
7};
8
9use error::LuwenError;
10use ttkmd_if::PciError;
11use luwen_if::{FnDriver, FnOptions};
12
13mod detect;
14pub mod error;
15mod wormhole;
16
17use wormhole::ethernet::{self, EthCommCoord};
18
19pub use detect::{detect_chips, detect_chips_fallible};
20pub use ttkmd_if::{DmaBuffer, DmaConfig, PciDevice, Tlb};
21
22#[derive(Clone)]
23pub struct ExtendedPciDeviceWrapper {
24    inner: Arc<RwLock<ExtendedPciDevice>>,
25}
26
27impl ExtendedPciDeviceWrapper {
28    pub fn borrow_mut(&self) -> RwLockWriteGuard<ExtendedPciDevice> {
29        self.inner.as_ref().write().unwrap()
30    }
31
32    pub fn borrow(&self) -> RwLockReadGuard<ExtendedPciDevice> {
33        self.inner.as_ref().read().unwrap()
34    }
35}
36
37pub struct ExtendedPciDevice {
38    pub device: PciDevice,
39
40    pub harvested_rows: u32,
41    pub grid_size_x: u8,
42    pub grid_size_y: u8,
43
44    pub eth_x: u8,
45    pub eth_y: u8,
46    pub command_q_addr: u32,
47    pub fake_block: bool,
48
49    pub default_tlb: u32,
50
51    pub ethernet_dma_buffer: HashMap<(u8, u8), DmaBuffer>,
52}
53
54impl ExtendedPciDevice {
55    pub fn setup_tlb(&mut self, index: u32, tlb: Tlb) -> Result<(u64, u64), PciError> {
56        ttkmd_if::tlb::setup_tlb(&mut self.device, index, tlb)
57    }
58
59    pub fn get_tlb(&self, index: u32) -> Result<Tlb, PciError> {
60        ttkmd_if::tlb::get_tlb(&self.device, index)
61    }
62
63    pub fn noc_write(&mut self, tlb_index: u32, addr: u64, data: &[u8]) -> Result<(), PciError> {
64        let mut written = 0;
65
66        let mut starting_tlb = self.get_tlb(tlb_index)?;
67
68        let len = data.len() as u64;
69
70        while written < len {
71            starting_tlb.local_offset = addr + written as u64;
72            let (bar_addr, slice_len) = self.setup_tlb(tlb_index, starting_tlb.clone())?;
73
74            let to_write = std::cmp::min(slice_len, len.saturating_sub(written));
75            self.write_block(
76                bar_addr as u32,
77                &data[written as usize..(written as usize + to_write as usize)],
78            )?;
79
80            written += to_write;
81        }
82
83        Ok(())
84    }
85
86    pub fn noc_read(&mut self, tlb_index: u32, addr: u64, data: &mut [u8]) -> Result<(), PciError> {
87        let mut read = 0;
88
89        let mut starting_tlb = self.get_tlb(tlb_index)?;
90
91        let len = data.len() as u64;
92
93        while read < len {
94            starting_tlb.local_offset = addr + read as u64;
95            let (bar_addr, slice_len) = self.setup_tlb(tlb_index, starting_tlb.clone())?;
96
97            let to_read = std::cmp::min(slice_len, len.saturating_sub(read));
98            self.read_block(
99                bar_addr as u32,
100                &mut data[read as usize..(read as usize + to_read as usize)],
101            )?;
102
103            read += to_read;
104        }
105
106        Ok(())
107    }
108
109    pub fn noc_write32(
110        &mut self,
111        tlb_index: u32,
112        noc_id: u8,
113        x: u8,
114        y: u8,
115        addr: u64,
116        data: u32,
117    ) -> Result<(), PciError> {
118        self.setup_tlb(
119            tlb_index,
120            Tlb {
121                x_end: x,
122                y_end: y,
123                noc_sel: noc_id,
124                ..Default::default()
125            },
126        )?;
127
128        self.noc_write(tlb_index, addr, &data.to_le_bytes())?;
129
130        Ok(())
131    }
132
133    pub fn noc_read32(
134        &mut self,
135        tlb_index: u32,
136        noc_id: u8,
137        x: u8,
138        y: u8,
139        addr: u64,
140    ) -> Result<u32, PciError> {
141        self.setup_tlb(
142            tlb_index,
143            Tlb {
144                x_end: x,
145                y_end: y,
146                noc_sel: noc_id,
147                ..Default::default()
148            },
149        )?;
150
151        let mut output = [0u8; 4];
152
153        self.noc_read(tlb_index, addr, &mut output)?;
154
155        Ok(u32::from_le_bytes(output))
156    }
157}
158
159impl ExtendedPciDevice {
160    pub fn open(pci_interface: usize) -> Result<ExtendedPciDeviceWrapper, ttkmd_if::PciOpenError> {
161        let device = PciDevice::open(pci_interface)?;
162
163        let (grid_size_x, grid_size_y) = match device.arch {
164            luwen_core::Arch::Grayskull => (13, 12),
165            luwen_core::Arch::Wormhole => (10, 12),
166            _ => unreachable!(),
167        };
168
169        Ok(ExtendedPciDeviceWrapper {
170            inner: Arc::new(RwLock::new(ExtendedPciDevice {
171                device,
172                harvested_rows: 0,
173                grid_size_x,
174                grid_size_y,
175                eth_x: 4,
176                eth_y: 6,
177                command_q_addr: 0,
178                fake_block: false,
179
180                default_tlb: 184,
181
182                ethernet_dma_buffer: HashMap::with_capacity(16),
183            })),
184        })
185    }
186
187    pub fn read_block(&mut self, addr: u32, data: &mut [u8]) -> Result<(), PciError> {
188        self.device.read_block(addr, data)
189    }
190
191    pub fn write_block(&mut self, addr: u32, data: &[u8]) -> Result<(), PciError> {
192        self.device.write_block(addr, data)
193    }
194}
195
196fn noc_write32(
197    device: &mut PciDevice,
198    tlb_index: u32,
199    noc_id: u8,
200    x: u8,
201    y: u8,
202    addr: u32,
203    data: u32,
204) -> Result<(), PciError> {
205    let (bar_addr, _slice_len) = ttkmd_if::tlb::setup_tlb(
206        device,
207        tlb_index,
208        Tlb {
209            local_offset: addr as u64,
210            x_end: x as u8,
211            y_end: y as u8,
212            noc_sel: noc_id,
213            mcast: false,
214            ..Default::default()
215        },
216    )?;
217
218    device.write_block(bar_addr as u32, data.to_le_bytes().as_slice())
219}
220
221fn noc_read32(
222    device: &mut PciDevice,
223    tlb_index: u32,
224    noc_id: u8,
225    x: u8,
226    y: u8,
227    addr: u32,
228) -> Result<u32, PciError> {
229    let (bar_addr, _slice_len) = ttkmd_if::tlb::setup_tlb(
230        device,
231        tlb_index,
232        Tlb {
233            local_offset: addr as u64,
234            x_end: x as u8,
235            y_end: y as u8,
236            noc_sel: noc_id,
237            mcast: false,
238            ..Default::default()
239        },
240    )?;
241
242    let mut data = [0u8; 4];
243    device.read_block(bar_addr as u32, &mut data)?;
244    Ok(u32::from_le_bytes(data))
245}
246
247pub fn comms_callback(
248    ud: &ExtendedPciDeviceWrapper,
249    op: FnOptions,
250) -> Result<(), Box<dyn std::error::Error>> {
251    Ok(comms_callback_inner(ud, op)?)
252}
253
254pub fn comms_callback_inner(
255    ud: &ExtendedPciDeviceWrapper,
256    op: FnOptions,
257) -> Result<(), LuwenError> {
258    match op {
259        FnOptions::Driver(op) => match op {
260            FnDriver::DeviceInfo(info) => {
261                let borrow = ud.borrow();
262                if !info.is_null() {
263                    unsafe {
264                        *info = Some(luwen_if::DeviceInfo {
265                            bus: borrow.device.physical.pci_bus,
266                            slot: borrow.device.physical.slot,
267                            function: borrow.device.physical.pci_function,
268                            domain: borrow.device.physical.pci_domain,
269
270                            interface_id: borrow.device.id as u32,
271
272                            vendor: borrow.device.physical.vendor_id,
273                            device_id: borrow.device.physical.device_id,
274                            bar_size: borrow.device.physical.bar_size_bytes,
275                        });
276                    }
277                }
278            }
279        },
280        FnOptions::Axi(op) => match op {
281            luwen_if::FnAxi::Read { addr, data, len } => {
282                if len > 0 {
283                    if len <= 4 {
284                        let output = ud.borrow_mut().device.read32(addr)?;
285                        let output = output.to_le_bytes();
286                        unsafe {
287                            data.copy_from_nonoverlapping(output.as_ptr(), len as usize);
288                        }
289                    } else {
290                        unsafe {
291                            ud.borrow_mut().read_block(
292                                addr,
293                                std::slice::from_raw_parts_mut(data, len as usize),
294                            )?
295                        };
296                    }
297                }
298            }
299            luwen_if::FnAxi::Write { addr, data, len } => {
300                if len > 0 {
301                    // Assuming here that u32 is our fundamental unit of transfer
302                    if len <= 4 {
303                        let to_write = if len == 4 {
304                            let slice = unsafe { std::slice::from_raw_parts(data, len as usize) };
305                            u32::from_le_bytes(slice.try_into().unwrap())
306                        } else {
307                            // We are reading less than a u32, so we need to read the existing value first
308                            // then writeback the new value with the lower len bytes replaced
309                            let value = ud.borrow_mut().device.read32(addr)?;
310                            let mut value = value.to_le_bytes();
311                            unsafe {
312                                value
313                                    .as_mut_ptr()
314                                    .copy_from_nonoverlapping(data, len as usize);
315                            }
316
317                            u32::from_le_bytes(value)
318                        };
319
320                        ud.borrow_mut().device.write32(addr, to_write)?;
321                    } else {
322                        unsafe {
323                            ud.borrow_mut()
324                                .write_block(addr, std::slice::from_raw_parts(data, len as usize))?
325                        };
326                    }
327                }
328            }
329        },
330        FnOptions::Noc(op) => match op {
331            luwen_if::FnNoc::Read {
332                noc_id,
333                x,
334                y,
335                addr,
336                data,
337                len,
338            } => {
339                let mut reader = ud.borrow_mut();
340                let reader: &mut ExtendedPciDevice = &mut reader;
341
342                reader.setup_tlb(
343                    reader.default_tlb,
344                    Tlb {
345                        local_offset: addr,
346                        x_end: x as u8,
347                        y_end: y as u8,
348                        noc_sel: noc_id,
349                        mcast: false,
350                        ..Default::default()
351                    },
352                )?;
353
354                reader.noc_read(reader.default_tlb, addr, unsafe {
355                    std::slice::from_raw_parts_mut(data, len as usize)
356                })?;
357            }
358            luwen_if::FnNoc::Write {
359                noc_id,
360                x,
361                y,
362                addr,
363                data,
364                len,
365            } => {
366                let mut writer = ud.borrow_mut();
367                let writer: &mut ExtendedPciDevice = &mut writer;
368
369                writer.setup_tlb(
370                    writer.default_tlb,
371                    Tlb {
372                        local_offset: addr,
373                        x_end: x as u8,
374                        y_end: y as u8,
375                        noc_sel: noc_id,
376                        mcast: false,
377                        ..Default::default()
378                    },
379                )?;
380
381                writer.noc_write(writer.default_tlb, addr, unsafe {
382                    std::slice::from_raw_parts(data, len as usize)
383                })?;
384            }
385            luwen_if::FnNoc::Broadcast {
386                noc_id,
387                addr,
388                data,
389                len,
390            } => {
391                let mut writer = ud.borrow_mut();
392                let writer: &mut ExtendedPciDevice = &mut writer;
393
394                let (x_start, y_start) = match writer.device.arch {
395                    luwen_core::Arch::Grayskull => (0, 0),
396                    luwen_core::Arch::Wormhole => (1, 0),
397                    luwen_core::Arch::Unknown(_) => todo!(),
398                };
399
400                writer.setup_tlb(
401                    writer.default_tlb,
402                    Tlb {
403                        local_offset: addr,
404                        x_start,
405                        y_start,
406                        x_end: writer.grid_size_x - 1,
407                        y_end: writer.grid_size_y - 1,
408                        noc_sel: noc_id,
409                        mcast: true,
410                        ..Default::default()
411                    },
412                )?;
413
414                writer.noc_write(writer.default_tlb, addr, unsafe {
415                    std::slice::from_raw_parts(data, len as usize)
416                })?;
417            }
418        },
419        FnOptions::Eth(op) => match op.rw {
420            luwen_if::FnNoc::Read {
421                noc_id,
422                x,
423                y,
424                addr,
425                data,
426                len,
427            } => {
428                let mut borrow = ud.borrow_mut();
429                let borrow: &mut ExtendedPciDevice = &mut borrow;
430
431                let eth_x = borrow.eth_x;
432                let eth_y = borrow.eth_y;
433
434                let command_q_addr = noc_read32(
435                    &mut borrow.device,
436                    borrow.default_tlb,
437                    0,
438                    eth_x,
439                    eth_y,
440                    0x170,
441                )?;
442                let fake_block = borrow.fake_block;
443
444                let default_tlb = borrow.default_tlb;
445                let read32 =
446                    |borrow: &mut _, addr| noc_read32(borrow, default_tlb, 0, eth_x, eth_y, addr);
447
448                let write32 = |borrow: &mut _, addr, data| {
449                    noc_write32(borrow, default_tlb, 0, eth_x, eth_y, addr, data)
450                };
451
452                let dma_buffer = {
453                    let key = (eth_x, eth_y);
454                    if !borrow.ethernet_dma_buffer.contains_key(&key) {
455                        // 1 MB buffer
456                        borrow
457                            .ethernet_dma_buffer
458                            .insert(key, borrow.device.allocate_dma_buffer(1 << 20)?);
459                    }
460
461                    // SAFETY: Can never get here without first inserting something into the hashmap
462                    unsafe { borrow.ethernet_dma_buffer.get_mut(&key).unwrap_unchecked() }
463                };
464
465                ethernet::fixup_queues(&mut borrow.device, read32, write32, command_q_addr)?;
466
467                if len <= 4 {
468                    let value = ethernet::eth_read32(
469                        &mut borrow.device,
470                        read32,
471                        write32,
472                        command_q_addr,
473                        EthCommCoord {
474                            coord: op.addr,
475                            noc_id,
476                            noc_x: x as u8,
477                            noc_y: y as u8,
478                            offset: addr,
479                        },
480                        std::time::Duration::from_secs(5 * 60),
481                    )?;
482
483                    let sl = unsafe { std::slice::from_raw_parts_mut(data, len as usize) };
484                    let vl = value.to_le_bytes();
485
486                    for (s, v) in sl.iter_mut().zip(vl.iter()) {
487                        *s = *v;
488                    }
489                } else {
490                    ethernet::block_read(
491                        &mut borrow.device,
492                        read32,
493                        write32,
494                        dma_buffer,
495                        command_q_addr,
496                        std::time::Duration::from_secs(5 * 60),
497                        fake_block,
498                        EthCommCoord {
499                            coord: op.addr,
500                            noc_id,
501                            noc_x: x as u8,
502                            noc_y: y as u8,
503                            offset: addr,
504                        },
505                        unsafe { std::slice::from_raw_parts_mut(data, len as usize) },
506                    )?;
507                }
508            }
509            luwen_if::FnNoc::Write {
510                noc_id,
511                x,
512                y,
513                addr,
514                data,
515                len,
516            } => {
517                let mut borrow = ud.borrow_mut();
518                let borrow: &mut ExtendedPciDevice = &mut borrow;
519
520                let eth_x = borrow.eth_x;
521                let eth_y = borrow.eth_y;
522
523                let command_q_addr =
524                    borrow.noc_read32(borrow.default_tlb, 0, eth_x, eth_y, 0x170)?;
525                let fake_block = borrow.fake_block;
526
527                let default_tlb = borrow.default_tlb;
528                let read32 =
529                    |borrow: &mut _, addr| noc_read32(borrow, default_tlb, 0, eth_x, eth_y, addr);
530
531                let write32 = |borrow: &mut _, addr, data| {
532                    noc_write32(borrow, default_tlb, 0, eth_x, eth_y, addr, data)
533                };
534
535                let dma_buffer = {
536                    let key = (eth_x, eth_y);
537                    if !borrow.ethernet_dma_buffer.contains_key(&key) {
538                        // 1 MB buffer
539                        borrow
540                            .ethernet_dma_buffer
541                            .insert(key, borrow.device.allocate_dma_buffer(1 << 20)?);
542                    }
543
544                    // SAFETY: Can never get here without first inserting something into the hashmap
545                    unsafe { borrow.ethernet_dma_buffer.get_mut(&key).unwrap_unchecked() }
546                };
547
548                ethernet::fixup_queues(&mut borrow.device, read32, write32, command_q_addr)?;
549
550                if len <= 4 {
551                    let sl = unsafe { std::slice::from_raw_parts(data, len as usize) };
552                    let mut value = 0u32;
553                    for s in sl.iter().rev() {
554                        value <<= 8;
555                        value |= *s as u32;
556                    }
557
558                    ethernet::eth_write32(
559                        &mut borrow.device,
560                        read32,
561                        write32,
562                        command_q_addr,
563                        EthCommCoord {
564                            coord: op.addr,
565                            noc_id,
566                            noc_x: x as u8,
567                            noc_y: y as u8,
568                            offset: addr,
569                        },
570                        std::time::Duration::from_secs(5 * 60),
571                        value,
572                    )?;
573                } else {
574                    ethernet::block_write(
575                        &mut borrow.device,
576                        read32,
577                        write32,
578                        dma_buffer,
579                        command_q_addr,
580                        std::time::Duration::from_secs(5 * 60),
581                        fake_block,
582                        EthCommCoord {
583                            coord: op.addr,
584                            noc_id,
585                            noc_x: x as u8,
586                            noc_y: y as u8,
587                            offset: addr,
588                        },
589                        unsafe { std::slice::from_raw_parts(data, len as usize) },
590                    )?;
591                }
592            }
593            luwen_if::FnNoc::Broadcast {
594                noc_id,
595                addr,
596                data,
597                len,
598            } => {
599                todo!("Tried to do an ethernet broadcast which is not supported, noc_id: {}, addr: {:#x}, data: {:p}, len: {:x}", noc_id, addr, data, len);
600            }
601        },
602    }
603
604    Ok(())
605}
606
607pub fn open(interface_id: usize) -> Result<luwen_if::chip::Chip, LuwenError> {
608    let ud = ExtendedPciDevice::open(interface_id)?;
609
610    let arch = ud.borrow().device.arch;
611
612    Ok(luwen_if::chip::Chip::open(
613        arch,
614        luwen_if::CallbackStorage::new(comms_callback, ud.clone()),
615    )?)
616}