use alloc::alloc::{alloc_zeroed, handle_alloc_error};
use core::{alloc::Layout, marker::PhantomData, ptr::NonNull};

use log::{debug, error, info, warn};
use volatile::VolatilePtr;

use crate::{
    Hal,
    ata::{
        ATA_CMD_ID_ATA, ATA_CMD_READ, ATA_CMD_READ_EXT, ATA_CMD_WRITE, ATA_CMD_WRITE_EXT,
        ATA_ID_FW_REV, ATA_ID_FW_REV_LEN, ATA_ID_PROD, ATA_ID_PROD_LEN, ATA_ID_SERNO,
        ATA_ID_SERNO_LEN, ATA_ID_WORDS, SATA_FIS_TYPE_REGISTER_H2D, ata_id_has_lba48,
        ata_id_n_sectors, ata_id_to_string,
    },
    hal::wait_until_timeout,
    mmio::{
        AhciMmio, AhciMmioVolatileFieldAccess, CAP, GenericHostControlVolatileFieldAccess, ICC,
        PortRegisters, PortRegistersVolatileFieldAccess, PxCMD, PxI,
    },
    types::{
        AHCI_MAX_BYTES_PER_CMD, AHCI_MAX_BYTES_PER_SG, AHCI_MAX_SG, ahci_cmd_hdr, ahci_cmd_list,
        ahci_cmd_tbl, ahci_cmd_tblVolatileFieldAccess, ahci_rx_fis, ahci_sg, sata_fis_h2d,
    },
};
25
26fn alloc<T: Sized>(align: usize) -> VolatilePtr<'static, T> {
27 unsafe {
28 VolatilePtr::new(NonNull::new_unchecked(
29 alloc_zeroed(Layout::from_size_align(size_of::<T>(), align).unwrap()).cast(),
30 ))
31 }
32}
33
34struct AhciPort<H> {
35 port: VolatilePtr<'static, PortRegisters>,
36
37 cmd_list: VolatilePtr<'static, ahci_cmd_list>,
38 #[allow(dead_code)]
39 fis: VolatilePtr<'static, ahci_rx_fis>,
40 cmd_tbl: VolatilePtr<'static, ahci_cmd_tbl>,
41
42 _h: PhantomData<H>,
43}
44
45impl<H: Hal> AhciPort<H> {
46 fn try_new(host: &VolatilePtr<'static, AhciMmio>, i: u8) -> Option<Self> {
47 let port = unsafe {
48 host.ports()
49 .map(|ports| ports.cast::<PortRegisters>().add(i as usize))
50 };
51
52 port.CMD().update(|cmd| cmd.with_ST(false).with_FRE(false));
54
55 if !wait_until_timeout::<H>(|| !port.CMD().read().CR(), 500) {
57 warn!("Port {i} stop engine timeout (CR)");
58 }
59 if !wait_until_timeout::<H>(|| !port.CMD().read().FR(), 500) {
60 warn!("Port {i} stop FIS receive timeout (FR)");
61 }
62
63 let tfd = port.TFD().read();
65 if tfd.STS_BSY() || tfd.STS_DRQ() {
66 debug!("Port {i} busy (TFD: {tfd:?}), trying CLO");
67 let cap = host.host().cap().read();
68 if cap.SCLO() {
69 port.CMD().update(|cmd| cmd.with_CLO(true));
70 if !wait_until_timeout::<H>(|| !port.CMD().read().CLO(), 1000) {
71 warn!("Port {i} CLO timeout");
72 }
73 }
74 }
75
76 port.CMD().update(|cmd| cmd.with_SUD(true));
78 if !wait_until_timeout::<H>(|| port.CMD().read().SUD(), 1000) {
79 warn!("Port {i} set Spin-Up Device timeout");
80 return None;
81 }
82
83 if !wait_until_timeout::<H>(
85 || {
86 let det = port.SSTS().read().DET();
87 det == 0x1 || det == 0x3
88 },
89 1000,
90 ) {
91 warn!("Port {i} sata link timeout");
92 return None;
93 }
94 debug!("Port {i} sata link up");
95
96 port.SERR().write(port.SERR().read());
98 port.IS().write(port.IS().read());
99
100 port.IE().write(PxI::default_enable().with_DP(true));
102
103 host.host().is().write(1 << i);
104
105 if port.SSTS().read().DET() != 3 {
106 if !wait_until_timeout::<H>(|| port.SSTS().read().DET() == 3, 1000) {
108 warn!(
109 "Port {i} physical link not established (DET={})",
110 port.SSTS().read().DET()
111 );
112 return None;
113 }
114 }
115
116 let cmd_list = alloc::<ahci_cmd_list>(1024);
117 let cmd_list_addr = H::virt_to_phys(cmd_list.as_raw_ptr().addr().get());
118 debug!(
119 "Port {i} cmd_list va={:#x} pa={:#x}",
120 cmd_list.as_raw_ptr().addr().get(),
121 cmd_list_addr
122 );
123 port.CLB().write(cmd_list_addr as u32);
124 port.CLBU().write((cmd_list_addr >> 32) as u32);
125
126 let fis = alloc::<ahci_rx_fis>(256);
127 let fis_addr = H::virt_to_phys(fis.as_raw_ptr().addr().get());
128 debug!(
129 "Port {i} fis va={:#x} pa={:#x}",
130 fis.as_raw_ptr().addr().get(),
131 fis_addr
132 );
133 port.FB().write(fis_addr as u32);
134 port.FBU().write((fis_addr >> 32) as u32);
135
136 let cmd_tbl = alloc::<ahci_cmd_tbl>(128);
137 debug!(
138 "Port {i} cmd_tbl va={:#x} pa={:#x}",
139 cmd_tbl.as_raw_ptr().addr().get(),
140 H::virt_to_phys(cmd_tbl.as_raw_ptr().addr().get())
141 );
142
143 port.CMD().write(
149 PxCMD::new()
150 .with_ICC(ICC::Active)
151 .with_FRE(true)
152 .with_POD(true)
153 .with_SUD(true)
154 .with_ST(true),
155 );
156
157 if !wait_until_timeout::<H>(
158 || {
159 let tfd = port.TFD().read();
160 if tfd.STS_ERR() {
161 }
163 !(tfd.STS_ERR() | tfd.STS_DRQ() | tfd.STS_BSY())
164 },
165 1000, ) {
167 warn!("Port {i} start timeout (TFD: {:?})", port.TFD().read());
168 return None;
169 }
170
171 Some(Self {
172 port,
173 cmd_list,
174 fis,
175 cmd_tbl,
176 _h: PhantomData,
177 })
178 }
179
180 fn exec_cmd(&mut self, cfis: sata_fis_h2d, buf: *mut [u8], is_write: bool) -> bool {
181 let slot: u32 = 0;
183
184 if !wait_until_timeout::<H>(|| self.port.CI().read() & 1 == 0, 1000) {
186 error!("Slot 0 busy timeout");
187 return false;
188 }
189
190 if buf.len() > AHCI_MAX_BYTES_PER_CMD {
191 error!("Exceeding max transfer data limit");
192 return false;
193 }
194
195 self.cmd_tbl.hdr().write(cfis);
197
198 let sg_cnt = if !buf.is_null() && !buf.is_empty() {
199 let sg_cnt = ((buf.len() - 1) / AHCI_MAX_BYTES_PER_SG) + 1;
200 if sg_cnt > AHCI_MAX_SG {
201 error!("Exceeding max sg limit");
202 return false;
203 }
204
205 let mut remaining = buf.len();
206 for i in 0..sg_cnt {
207 let offset = i * AHCI_MAX_BYTES_PER_SG;
208 let len = remaining.min(AHCI_MAX_BYTES_PER_SG);
209
210 let buf_addr = H::virt_to_phys(unsafe { (buf as *mut u8).add(offset).addr() });
211 let sg = unsafe { &mut self.cmd_tbl.sgs().map(|sg| sg.cast::<ahci_sg>().add(i)) };
212 sg.write(ahci_sg {
213 addr_lo: buf_addr as u32,
214 addr_hi: (buf_addr >> 32) as u32,
215 flags_size: (len - 1) as u32 & 0x3fffff, ..Default::default()
217 });
218
219 remaining -= len;
220 }
221
222 sg_cnt
223 } else {
224 0
225 };
226
227 let cfl = size_of::<sata_fis_h2d>() / 4; let opts = (cfl as u32) | ((sg_cnt as u32) << 16) | ((is_write as u32) << 6);
233
234 let cmd_tbl_addr = H::virt_to_phys(self.cmd_tbl.as_raw_ptr().addr().get());
235
236 debug!(
237 "exec_cmd: slot={} opts={:#x} cmd_tbl_addr={:#x} sg_cnt={} buf_len={}",
238 slot,
239 opts,
240 cmd_tbl_addr,
241 sg_cnt,
242 buf.len()
243 );
244
245 unsafe {
247 self.cmd_list
248 .map(|list| list.cast::<ahci_cmd_hdr>().add(slot as usize))
249 }
250 .write(ahci_cmd_hdr {
251 opts,
252 status: 0,
253 tbl_addr_lo: cmd_tbl_addr as u32,
254 tbl_addr_hi: (cmd_tbl_addr >> 32) as u32,
255 reserved: [0; 4],
256 });
257
258 H::flush_dcache();
259
260 self.port.CI().write(1 << slot);
262
263 if !wait_until_timeout::<H>(|| self.port.CI().read() & (1 << slot) == 0, 1000) {
265 let is = self.port.IS().read();
266 let tfd = self.port.TFD().read();
267 error!(
268 "AHCI command timeout: CI={:#x} IS={:?} TFD={:?}",
269 self.port.CI().read(),
270 is,
271 tfd
272 );
273 return false;
274 }
275
276 H::flush_dcache();
277 true
278 }
279}
280
281pub struct AhciDriver<H> {
282 #[allow(dead_code)]
283 mmio: VolatilePtr<'static, AhciMmio>,
284 port: AhciPort<H>,
285
286 block_size: usize,
287 max_lba: u64,
288 is_lba48: bool,
289
290 _h: PhantomData<H>,
291}
292
293unsafe impl<H: Hal> Send for AhciDriver<H> {}
298unsafe impl<H: Hal> Sync for AhciDriver<H> {}
299
300impl<H: Hal> AhciDriver<H> {
301 pub unsafe fn try_new(base: usize) -> Option<Self> {
311 let mmio = unsafe { VolatilePtr::new(NonNull::new(base as *mut _).unwrap()) };
313 let host = mmio.host();
314
315 host.ghc().update(|mut ghc| {
317 if !ghc.HR() {
318 ghc.set_HR(true);
319 }
320 ghc
321 });
322 if !wait_until_timeout::<H>(|| !host.ghc().read().HR(), 1000) {
323 error!("AHCI HBA reset timeout");
324 return None;
325 }
326
327 host.ghc().update(|ghc| ghc.with_AE(true));
329 wait_until_timeout::<H>(|| false, 1);
330
331 host.cap().write(CAP::new().with_SMPS(true).with_SSS(true));
333 host.pi().write(0xf);
334
335 let vs = host.vs().read();
336 info!("AHCI ver {vs}");
337
338 let cap = host.cap().read();
339 info!("AHCI cap {cap}");
340
341 let cap2 = host.cap2().read();
342 info!("AHCI cap2 {cap2:?}");
343
344 let pi = host.pi().read();
345 info!("AHCI ports implemented {pi}");
346
347 host.ghc().update(|ghc| ghc.with_IE(true));
348
349 let mut port = None;
350 for i in 0..cap.NP() + 1 {
351 if let Some(p) = AhciPort::<H>::try_new(&mmio, i) {
352 port = Some(p);
353 }
354 }
355
356 let Some(mut port) = port else {
357 error!("No AHCI ports initialized");
358 return None;
359 };
360
361 let mut id = [0u16; ATA_ID_WORDS];
362 port.exec_cmd(
363 sata_fis_h2d {
364 fis_type: SATA_FIS_TYPE_REGISTER_H2D,
365 pm_port_c: 0x80,
366 command: ATA_CMD_ID_ATA,
367 ..Default::default()
368 },
369 unsafe {
370 core::slice::from_raw_parts_mut(id.as_mut_ptr().cast::<u8>(), size_of_val(&id))
371 },
372 false,
373 );
374
375 let product = ata_id_to_string(&id, ATA_ID_PROD, ATA_ID_PROD_LEN);
376 let serial = ata_id_to_string(&id, ATA_ID_SERNO, ATA_ID_SERNO_LEN);
377 let rev = ata_id_to_string(&id, ATA_ID_FW_REV, ATA_ID_FW_REV_LEN);
378
379 info!("AHCI device: {product} {serial} {rev}");
380
381 let max_lba = ata_id_n_sectors(&id);
382 let is_lba48 = ata_id_has_lba48(&id);
383 let block_size = 512;
384
385 Some(Self {
386 mmio,
387 port,
388 block_size,
389 max_lba,
390 is_lba48,
391 _h: PhantomData,
392 })
393 }
394
395 pub fn capacity(&self) -> u64 {
396 self.max_lba
397 }
398
399 pub fn block_size(&self) -> usize {
400 self.block_size
401 }
402
403 pub fn read(&mut self, block_id: u64, buf: &mut [u8]) -> bool {
404 self.rw_common(block_id, buf, false)
405 }
406
407 pub fn write(&mut self, block_id: u64, buf: &[u8]) -> bool {
408 let buf_mut =
410 unsafe { core::slice::from_raw_parts_mut(buf.as_ptr() as *mut u8, buf.len()) };
411 self.rw_common(block_id, buf_mut, true)
412 }
413
414 fn rw_common(&mut self, block_id: u64, buf: &mut [u8], is_write: bool) -> bool {
415 let mut start = block_id;
416 let mut remaining_bytes = buf.len();
417 let mut buf_offset = 0;
418
419 while remaining_bytes > 0 {
420 let sectors = remaining_bytes.div_ceil(self.block_size);
421 let max_sectors = if self.is_lba48 { 65536 } else { 256 };
422 let count = sectors.min(max_sectors);
423 let byte_count = count * self.block_size;
424 let current_bytes = byte_count.min(remaining_bytes);
425
426 let mut fis = sata_fis_h2d {
428 fis_type: SATA_FIS_TYPE_REGISTER_H2D,
429 pm_port_c: 0x80,
430 ..Default::default()
431 };
432
433 if self.is_lba48 {
434 fis.command = if is_write {
435 ATA_CMD_WRITE_EXT
436 } else {
437 ATA_CMD_READ_EXT
438 };
439 fis.lba_low = start as u8;
440 fis.lba_mid = (start >> 8) as u8;
441 fis.lba_high = (start >> 16) as u8;
442 fis.lba_low_exp = (start >> 24) as u8;
443 fis.lba_mid_exp = (start >> 32) as u8;
444 fis.lba_high_exp = (start >> 40) as u8;
445 fis.device = 0x40; fis.sector_count = (count & 0xff) as u8;
447 fis.sector_count_exp = ((count >> 8) & 0xff) as u8;
448 } else {
449 fis.command = if is_write {
450 ATA_CMD_WRITE
451 } else {
452 ATA_CMD_READ
453 };
454 fis.lba_low = start as u8;
455 fis.lba_mid = (start >> 8) as u8;
456 fis.lba_high = (start >> 16) as u8;
457 fis.device = 0x40 | ((start >> 24) as u8 & 0x0f); fis.sector_count = (count & 0xff) as u8;
459 }
460
461 let slice = &mut buf[buf_offset..buf_offset + current_bytes];
462
463 if slice.as_ptr() as usize % 4 != 0 {
466 let mut temp_buf = alloc::vec![0u8; slice.len()];
467 if is_write {
468 temp_buf.copy_from_slice(slice);
469 }
470
471 if !self.port.exec_cmd(fis, temp_buf.as_mut_slice(), is_write) {
472 return false;
473 }
474
475 if !is_write {
476 slice.copy_from_slice(&temp_buf);
477 }
478 } else {
479 if !self.port.exec_cmd(fis, slice, is_write) {
480 return false;
481 }
482 }
483
484 start += count as u64;
485 remaining_bytes -= current_bytes;
486 buf_offset += current_bytes;
487 }
488 true
489 }
490}