fdt_edit/
fdt.rs

1use alloc::{
2    collections::BTreeMap,
3    format,
4    string::{String, ToString},
5    vec::Vec,
6};
7
8pub use fdt_raw::MemoryReservation;
9use fdt_raw::{FdtError, Phandle, Status};
10
11use crate::{
12    ClockType, Node, NodeIter, NodeIterMut, NodeKind, NodeMut, NodeRef,
13    encode::{FdtData, FdtEncoder},
14};
15
16/// 可编辑的 FDT
17#[derive(Clone)]
18pub struct Fdt {
19    /// 引导 CPU ID
20    pub boot_cpuid_phys: u32,
21    /// 内存保留块
22    pub memory_reservations: Vec<MemoryReservation>,
23    /// 根节点
24    pub root: Node,
25    /// phandle 到节点完整路径的缓存
26    phandle_cache: BTreeMap<Phandle, String>,
27}
28
29impl Default for Fdt {
30    fn default() -> Self {
31        Self::new()
32    }
33}
34
35impl Fdt {
36    /// 创建新的空 FDT
37    pub fn new() -> Self {
38        Self {
39            boot_cpuid_phys: 0,
40            memory_reservations: Vec::new(),
41            root: Node::new(""),
42            phandle_cache: BTreeMap::new(),
43        }
44    }
45
46    /// 从原始 FDT 数据解析
47    pub fn from_bytes(data: &[u8]) -> Result<Self, FdtError> {
48        let raw_fdt = fdt_raw::Fdt::from_bytes(data)?;
49        Self::from_raw(&raw_fdt)
50    }
51
52    /// 从原始指针解析
53    ///
54    /// # Safety
55    /// 调用者必须确保指针有效且指向有效的 FDT 数据
56    pub unsafe fn from_ptr(ptr: *mut u8) -> Result<Self, FdtError> {
57        let raw_fdt = unsafe { fdt_raw::Fdt::from_ptr(ptr)? };
58        Self::from_raw(&raw_fdt)
59    }
60
61    /// 从 fdt_raw::Fdt 转换
62    fn from_raw(raw_fdt: &fdt_raw::Fdt) -> Result<Self, FdtError> {
63        let header = raw_fdt.header();
64
65        let mut fdt = Fdt {
66            boot_cpuid_phys: header.boot_cpuid_phys,
67            memory_reservations: raw_fdt.memory_reservations().collect(),
68            root: Node::new(""),
69            phandle_cache: BTreeMap::new(),
70        };
71
72        // 构建节点树
73        // 使用栈来跟踪父节点,栈底是一个虚拟父节点
74        let mut node_stack: Vec<Node> = Vec::new();
75
76        for raw_node in raw_fdt.all_nodes() {
77            let level = raw_node.level();
78            let node = Node::from(&raw_node);
79
80            // 弹出栈直到达到正确的父级别
81            // level 0 = 根节点,应该直接放入空栈
82            // level 1 = 根节点的子节点,栈中应该只有根节点
83            while node_stack.len() > level {
84                let child = node_stack.pop().unwrap();
85                if let Some(parent) = node_stack.last_mut() {
86                    parent.add_child(child);
87                } else {
88                    // 这是根节点
89                    fdt.root = child;
90                }
91            }
92
93            node_stack.push(node);
94        }
95
96        // 弹出所有剩余节点
97        while let Some(child) = node_stack.pop() {
98            if let Some(parent) = node_stack.last_mut() {
99                parent.add_child(child);
100            } else {
101                // 这是根节点
102                fdt.root = child;
103            }
104        }
105
106        // 构建 phandle 缓存
107        fdt.rebuild_phandle_cache();
108
109        Ok(fdt)
110    }
111
112    /// 重建 phandle 缓存
113    pub fn rebuild_phandle_cache(&mut self) {
114        self.phandle_cache.clear();
115        let root_clone = self.root.clone();
116        self.build_phandle_cache_recursive(&root_clone, "/");
117    }
118
119    /// 递归构建 phandle 缓存
120    fn build_phandle_cache_recursive(&mut self, node: &Node, current_path: &str) {
121        // 检查节点是否有 phandle 属性
122        if let Some(phandle) = node.phandle() {
123            self.phandle_cache.insert(phandle, current_path.to_string());
124        }
125
126        // 递归处理子节点
127        for child in &node.children {
128            let child_name = child.name();
129            let child_path = if current_path == "/" {
130                format!("/{}", child_name)
131            } else {
132                format!("{}/{}", current_path, child_name)
133            };
134            self.build_phandle_cache_recursive(child, &child_path);
135        }
136    }
137
138    /// 规范化路径:如果是别名则解析为完整路径,否则确保以 / 开头
139    fn normalize_path(&self, path: &str) -> Option<String> {
140        if path.starts_with('/') {
141            Some(path.to_string())
142        } else {
143            // 尝试解析别名
144            self.resolve_alias(path).map(|s| s.to_string())
145        }
146    }
147
148    /// 解析别名,返回对应的完整路径
149    ///
150    /// 从 /aliases 节点查找别名对应的路径
151    pub fn resolve_alias(&self, alias: &str) -> Option<&str> {
152        let aliases_node = self.get_by_path("/aliases")?;
153        let prop = aliases_node.find_property(alias)?;
154        prop.as_str()
155    }
156
157    /// 获取所有别名
158    ///
159    /// 返回 (别名, 路径) 的列表
160    pub fn aliases(&self) -> Vec<(String, String)> {
161        let mut result = Vec::new();
162        if let Some(aliases_node) = self.get_by_path("/aliases") {
163            for prop in aliases_node.properties() {
164                let name = prop.name().to_string();
165                let path = prop.as_str().unwrap().to_string();
166                result.push((name, path));
167            }
168        }
169        result
170    }
171
172    /// 根据 phandle 查找节点
173    /// 返回 (节点引用, 完整路径)
174    pub fn find_by_phandle(&self, phandle: Phandle) -> Option<NodeRef<'_>> {
175        let path = self.phandle_cache.get(&phandle)?.clone();
176        self.get_by_path(&path)
177    }
178
179    /// 根据 phandle 查找节点(可变)
180    /// 返回 (节点可变引用, 完整路径)
181    pub fn find_by_phandle_mut(&mut self, phandle: Phandle) -> Option<NodeMut<'_>> {
182        let path = self.phandle_cache.get(&phandle)?.clone();
183        self.get_by_path_mut(&path)
184    }
185
186    /// 获取根节点
187    pub fn root<'a>(&'a self) -> NodeRef<'a> {
188        self.get_by_path("/").unwrap()
189    }
190
191    /// 获取根节点(可变)
192    pub fn root_mut<'a>(&'a mut self) -> NodeMut<'a> {
193        self.get_by_path_mut("/").unwrap()
194    }
195
196    /// 应用设备树覆盖 (Device Tree Overlay)
197    ///
198    /// 支持两种 overlay 格式:
199    /// 1. fragment 格式:包含 fragment@N 节点,每个 fragment 有 target/target-path 和 __overlay__
200    /// 2. 简单格式:直接包含 __overlay__ 节点
201    ///
202    /// # 示例
203    /// ```ignore
204    /// // fragment 格式
205    /// fragment@0 {
206    ///     target-path = "/soc";
207    ///     __overlay__ {
208    ///         new_node { ... };
209    ///     };
210    /// };
211    /// ```
212    pub fn apply_overlay(&mut self, overlay: &Fdt) -> Result<(), FdtError> {
213        // 遍历 overlay 根节点的所有子节点
214        for child in &overlay.root.children {
215            if child.name().starts_with("fragment@") || child.name() == "fragment" {
216                // fragment 格式
217                self.apply_fragment(child)?;
218            } else if child.name() == "__overlay__" {
219                // 简单格式:直接应用到根节点
220                self.merge_overlay_to_root(child)?;
221            } else if child.name() == "__symbols__"
222                || child.name() == "__fixups__"
223                || child.name() == "__local_fixups__"
224            {
225                // 跳过这些特殊节点
226                continue;
227            }
228        }
229
230        // 重建 phandle 缓存
231        self.rebuild_phandle_cache();
232
233        Ok(())
234    }
235
236    /// 应用单个 fragment
237    fn apply_fragment(&mut self, fragment: &Node) -> Result<(), FdtError> {
238        // 获取目标路径
239        let target_path = self.resolve_fragment_target(fragment)?;
240
241        // 找到 __overlay__ 子节点
242        let overlay_node = fragment
243            .get_child("__overlay__")
244            .ok_or(FdtError::NotFound)?;
245
246        // 找到目标节点并应用覆盖
247        // 需要克隆路径因为后面要修改 self
248        let target_path_owned = target_path.to_string();
249
250        // 应用覆盖到目标节点
251        self.apply_overlay_to_target(&target_path_owned, overlay_node)?;
252
253        Ok(())
254    }
255
256    /// 解析 fragment 的目标路径
257    fn resolve_fragment_target(&self, fragment: &Node) -> Result<String, FdtError> {
258        // 优先使用 target-path(字符串路径)
259        if let Some(prop) = fragment.get_property("target-path") {
260            return Ok(prop.as_str().ok_or(FdtError::Utf8Parse)?.to_string());
261        }
262
263        // 使用 target(phandle 引用)
264        if let Some(prop) = fragment.get_property("target") {
265            let ph = prop.get_u32().ok_or(FdtError::InvalidInput)?;
266            let ph = Phandle::from(ph);
267
268            // 通过 phandle 找到节点,然后构建路径
269            if let Some(node) = self.find_by_phandle(ph) {
270                return Ok(node.path());
271            }
272        }
273
274        Err(FdtError::NotFound)
275    }
276
277    /// 将 overlay 应用到目标节点
278    fn apply_overlay_to_target(
279        &mut self,
280        target_path: &str,
281        overlay_node: &Node,
282    ) -> Result<(), FdtError> {
283        // 找到目标节点
284        let mut target = self
285            .get_by_path_mut(target_path)
286            .ok_or(FdtError::NotFound)?;
287
288        // 合并 overlay 的属性和子节点
289        Self::merge_nodes(target.node, overlay_node);
290
291        Ok(())
292    }
293
294    /// 合并 overlay 节点到根节点
295    fn merge_overlay_to_root(&mut self, overlay: &Node) -> Result<(), FdtError> {
296        // 合并属性和子节点到根节点
297        for prop in overlay.properties() {
298            self.root.set_property(prop.clone());
299        }
300
301        for child in overlay.children() {
302            let child_name = child.name();
303            if let Some(existing) = self.root.get_child_mut(child_name) {
304                // 合并到现有子节点
305                Self::merge_nodes(existing, child);
306            } else {
307                // 添加新子节点
308                self.root.add_child(child.clone());
309            }
310        }
311
312        Ok(())
313    }
314
315    /// 递归合并两个节点
316    fn merge_nodes(target: &mut Node, source: &Node) {
317        // 合并属性(source 覆盖 target)
318        for prop in source.properties() {
319            target.set_property(prop.clone());
320        }
321
322        // 合并子节点
323        for source_child in source.children() {
324            let child_name = &source_child.name();
325            if let Some(target_child) = target.get_child_mut(child_name) {
326                // 递归合并
327                Self::merge_nodes(target_child, source_child);
328            } else {
329                // 添加新子节点
330                target.add_child(source_child.clone());
331            }
332        }
333    }
334
335    /// 删除节点(通过设置 status = "disabled" 或直接删除)
336    ///
337    /// 如果 overlay 中的节点有 status = "disabled",则禁用目标节点
338    pub fn apply_overlay_with_delete(
339        &mut self,
340        overlay: &Fdt,
341        delete_disabled: bool,
342    ) -> Result<(), FdtError> {
343        self.apply_overlay(overlay)?;
344
345        if delete_disabled {
346            // 移除所有 status = "disabled" 的节点
347            Self::remove_disabled_nodes(&mut self.root);
348            self.rebuild_phandle_cache();
349        }
350
351        Ok(())
352    }
353
354    /// 递归移除 disabled 的节点
355    fn remove_disabled_nodes(node: &mut Node) {
356        // 移除 disabled 的子节点
357        let mut to_remove = Vec::new();
358        for child in node.children() {
359            if matches!(child.status(), Some(Status::Disabled)) {
360                to_remove.push(child.name().to_string());
361            }
362        }
363
364        for child_name in to_remove {
365            node.remove_child(&child_name);
366        }
367
368        // 递归处理剩余子节点
369        for child in node.children_mut() {
370            Self::remove_disabled_nodes(child);
371        }
372    }
373
374    /// 通过精确路径删除节点及其子树
375    /// 只支持精确路径匹配,不支持模糊匹配
376    /// 支持通过别名删除节点,并自动删除对应的别名条目
377    ///
378    /// # 参数
379    /// - `path`: 删除路径,格式如 "soc/gpio@1000" 或 "/soc/gpio@1000" 或别名
380    ///
381    /// # 返回值
382    /// `Ok(Option<Node>)`: 如果找到并删除了节点,返回被删除的节点;如果路径不存在,返回 None
383    /// `Err(FdtError)`: 如果路径格式无效
384    ///
385    /// # 示例
386    /// ```rust
387    /// # use fdt_edit::{Fdt, Node};
388    /// let mut fdt = Fdt::new();
389    ///
390    /// // 先添加节点再删除
391    /// let mut soc = Node::new("soc");
392    /// soc.add_child(Node::new("gpio@1000"));
393    /// fdt.root.add_child(soc);
394    ///
395    /// // 精确删除节点(使用完整路径)
396    /// let removed = fdt.remove_node("/soc/gpio@1000")?;
397    /// assert!(removed.is_some());
398    ///
399    /// // 尝试删除不存在的节点会返回错误
400    /// let not_found = fdt.remove_node("/soc/nonexistent");
401    /// assert!(not_found.is_err());
402    /// # Ok::<(), fdt_raw::FdtError>(())
403    /// ```
404    pub fn remove_node(&mut self, path: &str) -> Result<Option<Node>, FdtError> {
405        let normalized_path = self.normalize_path(path).ok_or(FdtError::InvalidInput)?;
406
407        // 直接使用精确路径删除
408        let result = self.root.remove_by_path(&normalized_path)?;
409
410        // 如果删除成功且结果是 None,说明路径不存在
411        if result.is_none() {
412            return Err(FdtError::NotFound);
413        }
414
415        Ok(result)
416    }
417
418    /// 获取所有节点的深度优先迭代器
419    ///
420    /// 返回包含根节点及其所有子节点的迭代器,按照深度优先遍历顺序
421    pub fn all_nodes(&self) -> impl Iterator<Item = NodeRef<'_>> + '_ {
422        NodeIter::new(&self.root)
423    }
424
425    pub fn all_nodes_mut(&mut self) -> impl Iterator<Item = NodeMut<'_>> + '_ {
426        NodeIterMut::new(&mut self.root)
427    }
428
429    pub fn find_by_path<'a>(&'a self, path: &str) -> impl Iterator<Item = NodeRef<'a>> {
430        let path = self
431            .normalize_path(path)
432            .unwrap_or_else(|| path.to_string());
433
434        NodeIter::new(&self.root).filter_map(move |node_ref| {
435            if node_ref.path_eq_fuzzy(&path) {
436                Some(node_ref)
437            } else {
438                None
439            }
440        })
441    }
442
443    pub fn get_by_path<'a>(&'a self, path: &str) -> Option<NodeRef<'a>> {
444        let path = self.normalize_path(path)?;
445        NodeIter::new(&self.root).find_map(move |node_ref| {
446            if node_ref.path_eq(&path) {
447                Some(node_ref)
448            } else {
449                None
450            }
451        })
452    }
453
454    pub fn get_by_path_mut<'a>(&'a mut self, path: &str) -> Option<NodeMut<'a>> {
455        let path = self.normalize_path(path)?;
456        NodeIterMut::new(&mut self.root).find_map(move |node_mut| {
457            if node_mut.path_eq(&path) {
458                Some(node_mut)
459            } else {
460                None
461            }
462        })
463    }
464
465    pub fn find_compatible(&self, compatible: &[&str]) -> Vec<NodeRef<'_>> {
466        let mut results = Vec::new();
467        for node_ref in self.all_nodes() {
468            let Some(ls) = node_ref.compatible() else {
469                continue;
470            };
471
472            for comp in ls {
473                if compatible.contains(&comp) {
474                    results.push(node_ref);
475                    break;
476                }
477            }
478        }
479        results
480    }
481
482    /// 序列化为 FDT 二进制数据
483    pub fn encode(&self) -> FdtData {
484        FdtEncoder::new(self).encode()
485    }
486}
487
488impl core::fmt::Display for Fdt {
489    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
490        // 输出 DTS 头部信息
491        writeln!(f, "/dts-v1/;")?;
492
493        // 输出内存保留块
494        for reservation in &self.memory_reservations {
495            writeln!(
496                f,
497                "/memreserve/ 0x{:x} 0x{:x};",
498                reservation.address, reservation.size
499            )?;
500        }
501
502        // 输出根节点
503        writeln!(f, "{}", self.root)
504    }
505}
506
507impl core::fmt::Debug for Fdt {
508    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
509        if f.alternate() {
510            // Deep debug format with node traversal
511            self.fmt_debug_deep(f)
512        } else {
513            // Simple debug format (current behavior)
514            f.debug_struct("Fdt")
515                .field("boot_cpuid_phys", &self.boot_cpuid_phys)
516                .field("memory_reservations_count", &self.memory_reservations.len())
517                .field("root_node_name", &self.root.name)
518                .field("total_nodes", &self.root.children.len())
519                .field("phandle_cache_size", &self.phandle_cache.len())
520                .finish()
521        }
522    }
523}
524
525impl Fdt {
526    fn fmt_debug_deep(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
527        writeln!(f, "Fdt {{")?;
528        writeln!(f, "    boot_cpuid_phys: 0x{:x},", self.boot_cpuid_phys)?;
529        writeln!(
530            f,
531            "    memory_reservations_count: {},",
532            self.memory_reservations.len()
533        )?;
534        writeln!(f, "    phandle_cache_size: {},", self.phandle_cache.len())?;
535        writeln!(f, "    nodes:")?;
536
537        // 遍历所有节点并打印带缩进的调试信息
538        for (i, node) in self.all_nodes().enumerate() {
539            self.fmt_node_debug(f, &node, 2, i)?;
540        }
541
542        writeln!(f, "}}")
543    }
544
545    fn fmt_node_debug(
546        &self,
547        f: &mut core::fmt::Formatter<'_>,
548        node: &NodeRef,
549        indent: usize,
550        index: usize,
551    ) -> core::fmt::Result {
552        // 打印缩进
553        for _ in 0..indent {
554            write!(f, "    ")?;
555        }
556
557        // 打印节点索引和基本信息
558        write!(f, "[{:03}] {}: ", index, node.name())?;
559
560        // 根据节点类型打印特定信息
561        match node.as_ref() {
562            NodeKind::Clock(clock) => {
563                write!(f, "Clock")?;
564                if let ClockType::Fixed(fixed) = &clock.kind {
565                    write!(f, " (Fixed, {}Hz)", fixed.frequency)?;
566                } else {
567                    write!(f, " (Provider)")?;
568                }
569                if !clock.clock_output_names.is_empty() {
570                    write!(f, ", outputs: {:?}", clock.clock_output_names)?;
571                }
572                write!(f, ", cells={}", clock.clock_cells)?;
573            }
574            NodeKind::Pci(pci) => {
575                write!(f, "PCI")?;
576                if let Some(bus_range) = pci.bus_range() {
577                    write!(f, " (bus: {:?})", bus_range)?;
578                }
579                write!(f, ", interrupt-cells={}", pci.interrupt_cells())?;
580            }
581            NodeKind::InterruptController(ic) => {
582                write!(f, "InterruptController")?;
583                if let Some(cells) = ic.interrupt_cells() {
584                    write!(f, " (cells={})", cells)?;
585                }
586                let compatibles = ic.compatibles();
587                if !compatibles.is_empty() {
588                    write!(f, ", compatible: {:?}", compatibles)?;
589                }
590            }
591            NodeKind::Memory(mem) => {
592                write!(f, "Memory")?;
593                let regions = mem.regions();
594                if !regions.is_empty() {
595                    write!(f, " ({} regions", regions.len())?;
596                    for (i, region) in regions.iter().take(2).enumerate() {
597                        write!(f, ", [{}]: 0x{:x}+0x{:x}", i, region.address, region.size)?;
598                    }
599                    if regions.len() > 2 {
600                        write!(f, ", ...")?;
601                    }
602                    write!(f, ")")?;
603                }
604            }
605            NodeKind::Generic(_) => {
606                write!(f, "Generic")?;
607            }
608        }
609
610        // 打印 phandle 信息
611        if let Some(phandle) = node.phandle() {
612            write!(f, ", phandle={}", phandle)?;
613        }
614
615        // 打印地址和大小 cells 信息
616        if let Some(address_cells) = node.address_cells() {
617            write!(f, ", #address-cells={}", address_cells)?;
618        }
619        if let Some(size_cells) = node.size_cells() {
620            write!(f, ", #size-cells={}", size_cells)?;
621        }
622
623        writeln!(f)?;
624
625        Ok(())
626    }
627}