Skip to main content

simple_ahci/
ahci.rs

1use alloc::alloc::alloc_zeroed;
2use core::{alloc::Layout, marker::PhantomData, ptr::NonNull};
3
4use log::{debug, error, info, warn};
5use volatile::VolatilePtr;
6
7use crate::{
8    Hal,
9    ata::{
10        ATA_CMD_ID_ATA, ATA_CMD_READ, ATA_CMD_READ_EXT, ATA_CMD_WRITE, ATA_CMD_WRITE_EXT,
11        ATA_ID_FW_REV, ATA_ID_FW_REV_LEN, ATA_ID_PROD, ATA_ID_PROD_LEN, ATA_ID_SERNO,
12        ATA_ID_SERNO_LEN, ATA_ID_WORDS, SATA_FIS_TYPE_REGISTER_H2D, ata_id_has_lba48,
13        ata_id_n_sectors, ata_id_to_string,
14    },
15    hal::wait_until_timeout,
16    mmio::{
17        AhciMmio, AhciMmioVolatileFieldAccess, CAP, GenericHostControlVolatileFieldAccess, ICC,
18        PortRegisters, PortRegistersVolatileFieldAccess, PxCMD, PxI,
19    },
20    types::{
21        AHCI_MAX_BYTES_PER_CMD, AHCI_MAX_BYTES_PER_SG, AHCI_MAX_SG, ahci_cmd_hdr, ahci_cmd_list,
22        ahci_cmd_tbl, ahci_cmd_tblVolatileFieldAccess, ahci_rx_fis, ahci_sg, sata_fis_h2d,
23    },
24};
25
26fn alloc<T: Sized>(align: usize) -> VolatilePtr<'static, T> {
27    unsafe {
28        VolatilePtr::new(NonNull::new_unchecked(
29            alloc_zeroed(Layout::from_size_align(size_of::<T>(), align).unwrap()).cast(),
30        ))
31    }
32}
33
34struct AhciPort<H> {
35    port: VolatilePtr<'static, PortRegisters>,
36
37    cmd_list: VolatilePtr<'static, ahci_cmd_list>,
38    #[allow(dead_code)]
39    fis: VolatilePtr<'static, ahci_rx_fis>,
40    cmd_tbl: VolatilePtr<'static, ahci_cmd_tbl>,
41
42    _h: PhantomData<H>,
43}
44
45impl<H: Hal> AhciPort<H> {
46    fn try_new(host: &VolatilePtr<'static, AhciMmio>, i: u8) -> Option<Self> {
47        let port = unsafe {
48            host.ports()
49                .map(|ports| ports.cast::<PortRegisters>().add(i as usize))
50        };
51
52        // 1. Stop the port (ST=0, FRE=0)
53        port.CMD().update(|cmd| cmd.with_ST(false).with_FRE(false));
54
55        // Wait for CR and FR to clear
56        if !wait_until_timeout::<H>(|| !port.CMD().read().CR(), 500) {
57            warn!("Port {i} stop engine timeout (CR)");
58        }
59        if !wait_until_timeout::<H>(|| !port.CMD().read().FR(), 500) {
60            warn!("Port {i} stop FIS receive timeout (FR)");
61        }
62
63        // 2. Check if device is busy (BSY or DRQ) and try CLO
64        let tfd = port.TFD().read();
65        if tfd.STS_BSY() || tfd.STS_DRQ() {
66            debug!("Port {i} busy (TFD: {tfd:?}), trying CLO");
67            let cap = host.host().cap().read();
68            if cap.SCLO() {
69                port.CMD().update(|cmd| cmd.with_CLO(true));
70                if !wait_until_timeout::<H>(|| !port.CMD().read().CLO(), 1000) {
71                    warn!("Port {i} CLO timeout");
72                }
73            }
74        }
75
76        // 3. Spin up
77        port.CMD().update(|cmd| cmd.with_SUD(true));
78        if !wait_until_timeout::<H>(|| port.CMD().read().SUD(), 1000) {
79            warn!("Port {i} set Spin-Up Device timeout");
80            return None;
81        }
82
83        // 4. Wait for Link Up
84        if !wait_until_timeout::<H>(
85            || {
86                let det = port.SSTS().read().DET();
87                det == 0x1 || det == 0x3
88            },
89            1000,
90        ) {
91            warn!("Port {i} sata link timeout");
92            return None;
93        }
94        debug!("Port {i} sata link up");
95
96        // 5. Clear Errors
97        port.SERR().write(port.SERR().read());
98        port.IS().write(port.IS().read());
99
100        // 6. Enable Interrupts
101        port.IE().write(PxI::default_enable().with_DP(true));
102
103        host.host().is().write(1 << i);
104
105        if port.SSTS().read().DET() != 3 {
106            // Try to wait a bit more if it is 1
107            if !wait_until_timeout::<H>(|| port.SSTS().read().DET() == 3, 1000) {
108                warn!(
109                    "Port {i} physical link not established (DET={})",
110                    port.SSTS().read().DET()
111                );
112                return None;
113            }
114        }
115
116        let cmd_list = alloc::<ahci_cmd_list>(1024);
117        let cmd_list_addr = H::virt_to_phys(cmd_list.as_raw_ptr().addr().get());
118        debug!(
119            "Port {i} cmd_list va={:#x} pa={:#x}",
120            cmd_list.as_raw_ptr().addr().get(),
121            cmd_list_addr
122        );
123        port.CLB().write(cmd_list_addr as u32);
124        port.CLBU().write((cmd_list_addr >> 32) as u32);
125
126        let fis = alloc::<ahci_rx_fis>(256);
127        let fis_addr = H::virt_to_phys(fis.as_raw_ptr().addr().get());
128        debug!(
129            "Port {i} fis va={:#x} pa={:#x}",
130            fis.as_raw_ptr().addr().get(),
131            fis_addr
132        );
133        port.FB().write(fis_addr as u32);
134        port.FBU().write((fis_addr >> 32) as u32);
135
136        let cmd_tbl = alloc::<ahci_cmd_tbl>(128);
137        debug!(
138            "Port {i} cmd_tbl va={:#x} pa={:#x}",
139            cmd_tbl.as_raw_ptr().addr().get(),
140            H::virt_to_phys(cmd_tbl.as_raw_ptr().addr().get())
141        );
142
143        // Note: We used to check for BSY/DRQ here, but some devices (like QEMU)
144        // might be busy after spin-up/link-up. The original driver for reference
145        // proceeds to start the port without waiting for BSY to clear here.
146        // It waits for BSY to clear *after* setting the start bits.
147
148        port.CMD().write(
149            PxCMD::new()
150                .with_ICC(ICC::Active)
151                .with_FRE(true)
152                .with_POD(true)
153                .with_SUD(true)
154                .with_ST(true),
155        );
156
157        if !wait_until_timeout::<H>(
158            || {
159                let tfd = port.TFD().read();
160                if tfd.STS_ERR() {
161                    // warn!("Port {i} error after start (TFD: {:?})", tfd);
162                }
163                !(tfd.STS_ERR() | tfd.STS_DRQ() | tfd.STS_BSY())
164            },
165            1000,   //try not to wait too long
166        ) {
167            warn!("Port {i} start timeout (TFD: {:?})", port.TFD().read());
168            return None;
169        }
170
171        Some(Self {
172            port,
173            cmd_list,
174            fis,
175            cmd_tbl,
176            _h: PhantomData,
177        })
178    }
179
180    fn exec_cmd(&mut self, cfis: sata_fis_h2d, buf: *mut [u8], is_write: bool) -> bool {
181        // Always use slot 0 for simplicity (like reference driver)
182        let slot: u32 = 0;
183
184        // Wait for slot 0 to be free
185        if !wait_until_timeout::<H>(|| self.port.CI().read() & 1 == 0, 1000) {
186            error!("Slot 0 busy timeout");
187            return false;
188        }
189
190        if buf.len() > AHCI_MAX_BYTES_PER_CMD {
191            error!("Exceeding max transfer data limit");
192            return false;
193        }
194
195        // Write command FIS to command table
196        self.cmd_tbl.hdr().write(cfis);
197
198        let sg_cnt = if !buf.is_null() && !buf.is_empty() {
199            let sg_cnt = ((buf.len() - 1) / AHCI_MAX_BYTES_PER_SG) + 1;
200            if sg_cnt > AHCI_MAX_SG {
201                error!("Exceeding max sg limit");
202                return false;
203            }
204
205            let mut remaining = buf.len();
206            for i in 0..sg_cnt {
207                let offset = i * AHCI_MAX_BYTES_PER_SG;
208                let len = remaining.min(AHCI_MAX_BYTES_PER_SG);
209
210                let buf_addr = H::virt_to_phys(unsafe { (buf as *mut u8).add(offset).addr() });
211                let sg = unsafe { &mut self.cmd_tbl.sgs().map(|sg| sg.cast::<ahci_sg>().add(i)) };
212                sg.write(ahci_sg {
213                    addr_lo: buf_addr as u32,
214                    addr_hi: (buf_addr >> 32) as u32,
215                    flags_size: (len - 1) as u32 & 0x3fffff, // DBC: Data Byte Count (0-based)
216                    ..Default::default()
217                });
218
219                remaining -= len;
220            }
221
222            sg_cnt
223        } else {
224            0
225        };
226
227        // Build command header options:
228        // Bits 0-4: Command FIS length in DWORDs (5 for sata_fis_h2d which is 20 bytes
229        // = 5 DWORDs) Bit 6: Write (1) or Read (0)
230        // Bits 16-31: PRDTL (Physical Region Descriptor Table Length)
231        let cfl = size_of::<sata_fis_h2d>() / 4; // 20 / 4 = 5
232        let opts = (cfl as u32) | ((sg_cnt as u32) << 16) | ((is_write as u32) << 6);
233
234        let cmd_tbl_addr = H::virt_to_phys(self.cmd_tbl.as_raw_ptr().addr().get());
235
236        debug!(
237            "exec_cmd: slot={} opts={:#x} cmd_tbl_addr={:#x} sg_cnt={} buf_len={}",
238            slot,
239            opts,
240            cmd_tbl_addr,
241            sg_cnt,
242            buf.len()
243        );
244
245        // Write command header to slot 0
246        unsafe {
247            self.cmd_list
248                .map(|list| list.cast::<ahci_cmd_hdr>().add(slot as usize))
249        }
250        .write(ahci_cmd_hdr {
251            opts,
252            status: 0,
253            tbl_addr_lo: cmd_tbl_addr as u32,
254            tbl_addr_hi: (cmd_tbl_addr >> 32) as u32,
255            reserved: [0; 4],
256        });
257
258        H::flush_dcache();
259
260        // Issue command
261        self.port.CI().write(1 << slot);
262
263        // Wait for completion
264        if !wait_until_timeout::<H>(|| self.port.CI().read() & (1 << slot) == 0, 1000) {
265            let is = self.port.IS().read();
266            let tfd = self.port.TFD().read();
267            error!(
268                "AHCI command timeout: CI={:#x} IS={:?} TFD={:?}",
269                self.port.CI().read(),
270                is,
271                tfd
272            );
273            return false;
274        }
275
276        H::flush_dcache();
277        true
278    }
279}
280
281pub struct AhciDriver<H> {
282    #[allow(dead_code)]
283    mmio: VolatilePtr<'static, AhciMmio>,
284    port: AhciPort<H>,
285
286    block_size: usize,
287    max_lba: u64,
288    is_lba48: bool,
289
290    _h: PhantomData<H>,
291}
292
293/// Safety:
294/// - `Send`: The driver takes ownership of the MMIO region and can be safely moved between threads.
295/// - `Sync`: The driver's mutating operations require `&mut self`, ensuring exclusive access.
296///   Read-only operations (like getting block size) are safe to perform concurrently.
297unsafe impl<H: Hal> Send for AhciDriver<H> {}
298unsafe impl<H: Hal> Sync for AhciDriver<H> {}
299
300impl<H: Hal> AhciDriver<H> {
301    /// Try to construct a new AHCI driver from the given MMIO base address.
302    ///
303    /// # Safety
304    ///
305    /// The caller must ensure that:
306    /// - `base` is a valid virtual address pointing to the AHCI controller's MMIO register block.
307    /// - The memory region starting at `base` is properly mapped and accessible.
308    /// - No other code is concurrently accessing the same AHCI controller.
309    /// - The AHCI controller hardware is present and functional at the given address.
310    pub unsafe fn try_new(base: usize) -> Option<Self> {
311        // SAFETY: The caller guarantees `base` is a valid AHCI MMIO base address.
312        let mmio = unsafe { VolatilePtr::new(NonNull::new(base as *mut _).unwrap()) };
313        let host = mmio.host();
314
315        // reset ahci controller
316        host.ghc().update(|mut ghc| {
317            if !ghc.HR() {
318                ghc.set_HR(true);
319            }
320            ghc
321        });
322        if !wait_until_timeout::<H>(|| !host.ghc().read().HR(), 1000) {
323            error!("AHCI HBA reset timeout");
324            return None;
325        }
326
327        // enable ahci
328        host.ghc().update(|ghc| ghc.with_AE(true));
329        wait_until_timeout::<H>(|| false, 1);
330
331        // init cap and pi
332        host.cap().write(CAP::new().with_SMPS(true).with_SSS(true));
333        host.pi().write(0xf);
334
335        let vs = host.vs().read();
336        info!("AHCI ver {vs}");
337
338        let cap = host.cap().read();
339        info!("AHCI cap {cap}");
340
341        let cap2 = host.cap2().read();
342        info!("AHCI cap2 {cap2:?}");
343
344        let pi = host.pi().read();
345        info!("AHCI ports implemented {pi}");
346
347        host.ghc().update(|ghc| ghc.with_IE(true));
348
349        let mut port = None;
350        for i in 0..cap.NP() + 1 {
351            if let Some(p) = AhciPort::<H>::try_new(&mmio, i) {
352                port = Some(p);
353            }
354        }
355
356        let Some(mut port) = port else {
357            error!("No AHCI ports initialized");
358            return None;
359        };
360
361        let mut id = [0u16; ATA_ID_WORDS];
362        port.exec_cmd(
363            sata_fis_h2d {
364                fis_type: SATA_FIS_TYPE_REGISTER_H2D,
365                pm_port_c: 0x80,
366                command: ATA_CMD_ID_ATA,
367                ..Default::default()
368            },
369            unsafe {
370                core::slice::from_raw_parts_mut(id.as_mut_ptr().cast::<u8>(), size_of_val(&id))
371            },
372            false,
373        );
374
375        let product = ata_id_to_string(&id, ATA_ID_PROD, ATA_ID_PROD_LEN);
376        let serial = ata_id_to_string(&id, ATA_ID_SERNO, ATA_ID_SERNO_LEN);
377        let rev = ata_id_to_string(&id, ATA_ID_FW_REV, ATA_ID_FW_REV_LEN);
378
379        info!("AHCI device: {product} {serial} {rev}");
380
381        let max_lba = ata_id_n_sectors(&id);
382        let is_lba48 = ata_id_has_lba48(&id);
383        let block_size = 512;
384
385        Some(Self {
386            mmio,
387            port,
388            block_size,
389            max_lba,
390            is_lba48,
391            _h: PhantomData,
392        })
393    }
394
395    pub fn capacity(&self) -> u64 {
396        self.max_lba
397    }
398
399    pub fn block_size(&self) -> usize {
400        self.block_size
401    }
402
403    pub fn read(&mut self, block_id: u64, buf: &mut [u8]) -> bool {
404        self.rw_common(block_id, buf, false)
405    }
406
407    pub fn write(&mut self, block_id: u64, buf: &[u8]) -> bool {
408        // Cast to mut ptr for internal handling, but we won't modify it if it's write
409        let buf_mut =
410            unsafe { core::slice::from_raw_parts_mut(buf.as_ptr() as *mut u8, buf.len()) };
411        self.rw_common(block_id, buf_mut, true)
412    }
413
414    fn rw_common(&mut self, block_id: u64, buf: &mut [u8], is_write: bool) -> bool {
415        let mut start = block_id;
416        let mut remaining_bytes = buf.len();
417        let mut buf_offset = 0;
418
419        while remaining_bytes > 0 {
420            let sectors = remaining_bytes.div_ceil(self.block_size);
421            let max_sectors = if self.is_lba48 { 65536 } else { 256 };
422            let count = sectors.min(max_sectors);
423            let byte_count = count * self.block_size;
424            let current_bytes = byte_count.min(remaining_bytes);
425
426            // Construct FIS
427            let mut fis = sata_fis_h2d {
428                fis_type: SATA_FIS_TYPE_REGISTER_H2D,
429                pm_port_c: 0x80,
430                ..Default::default()
431            };
432
433            if self.is_lba48 {
434                fis.command = if is_write {
435                    ATA_CMD_WRITE_EXT
436                } else {
437                    ATA_CMD_READ_EXT
438                };
439                fis.lba_low = start as u8;
440                fis.lba_mid = (start >> 8) as u8;
441                fis.lba_high = (start >> 16) as u8;
442                fis.lba_low_exp = (start >> 24) as u8;
443                fis.lba_mid_exp = (start >> 32) as u8;
444                fis.lba_high_exp = (start >> 40) as u8;
445                fis.device = 0x40; // LBA mode
446                fis.sector_count = (count & 0xff) as u8;
447                fis.sector_count_exp = ((count >> 8) & 0xff) as u8;
448            } else {
449                fis.command = if is_write {
450                    ATA_CMD_WRITE
451                } else {
452                    ATA_CMD_READ
453                };
454                fis.lba_low = start as u8;
455                fis.lba_mid = (start >> 8) as u8;
456                fis.lba_high = (start >> 16) as u8;
457                fis.device = 0x40 | ((start >> 24) as u8 & 0x0f); // LBA mode + top 4 bits
458                fis.sector_count = (count & 0xff) as u8;
459            }
460
461            let slice = &mut buf[buf_offset..buf_offset + current_bytes];
462
463            // Check buffer alignment. AHCI requires data buffer to be even-byte aligned.
464            // We use 4-byte alignment to be safe.
465            if slice.as_ptr() as usize % 4 != 0 {
466                let mut temp_buf = alloc::vec![0u8; slice.len()];
467                if is_write {
468                    temp_buf.copy_from_slice(slice);
469                }
470                
471                if !self.port.exec_cmd(fis, temp_buf.as_mut_slice(), is_write) {
472                    return false;
473                }
474
475                if !is_write {
476                    slice.copy_from_slice(&temp_buf);
477                }
478            } else {
479                if !self.port.exec_cmd(fis, slice, is_write) {
480                    return false;
481                }
482            }
483
484            start += count as u64;
485            remaining_bytes -= current_bytes;
486            buf_offset += current_bytes;
487        }
488        true
489    }
490}