unmp/net/
mod.rs

1//! # 网络层
2//!
3//! 网络层(net)规定了数据包格式(Packet),负责维护路由表
4
5mod history;
6pub mod packet;
7mod route;
8
9use crate::consts::*;
10use crate::id::*;
11use crate::protocol;
12use crate::Connection;
13use futures_intrusive::sync::GenericMutex;
14use history::History;
15use log::{info, warn};
16use packet::{Packet, PacketHeader};
17pub use route::DisconnectCb;
18use route::Router;
19use spin::{Lazy, Once};
20use unmp_link::Link;
21
/// Async mutex used by this module: `futures_intrusive`'s `GenericMutex`
/// with a `spin::Mutex<()>` as the underlying raw lock.
type Mutex<T> = GenericMutex<spin::Mutex<()>, T>;

/// Global routing table, built lazily on first access. The `false` passed to
/// `Mutex::new` is the mutex's fairness flag (see `GenericMutex::new`).
static ROUTER: Lazy<Mutex<Router>> = Lazy::new(|| Mutex::new(Router::new(), false));
/// Configuration of this device; written once by `init`.
static CFG: Once<Config> = Once::new();
/// Record of received packets; `when_recv` consults it to drop repeats.
static HISTORY: History = History::new();
27
28/// 配置
29pub struct Config {
30    /// 本设备ID
31    id: Id,
32    /// 转发开关
33    relay: bool,
34}
35impl Config {
36    pub fn new(id: Id) -> Self {
37        Config {
38            id: id,
39            relay: false,
40        }
41    }
42    pub fn id(&self) -> &Id {
43        &self.id
44    }
45    pub fn set_id(&mut self, id: Id) {
46        self.id = id;
47    }
48    pub fn relay(&self) -> bool {
49        self.relay
50    }
51    pub fn set_relay(&mut self, relay: bool) {
52        self.relay = relay;
53    }
54}
55
56/// 初始化
57pub fn init(cfg: Config) {
58    CFG.call_once(|| cfg);
59    unmp_link::on_destroy(when_link_disconnect);
60}
61/// 读取本机ID
62pub fn get_id() -> Id {
63    return CFG.get().unwrap().id.clone();
64}
65/// 通用发送函数
66async fn send_common(
67    head: &PacketHeader,
68    data: &[u8],
69    dst: Option<&Id>,
70    dst_link: Option<&Link>,
71    exclude: Option<&Link>,
72) -> Result<(), ()> {
73    let buf = head.generate(data);
74    info!("net send: {:02X?}.", buf);
75    let id: &Id = if let Some(id) = dst { id } else { head.dst() };
76    if let Some(link) = dst_link {
77        // 指定了目标链路
78        if let Some(origin) = exclude {
79            if origin == link {
80                return Err(());
81            }
82        }
83        if let Err(_) = link.send(&buf).await {
84            return Err(());
85        } else {
86            return Ok(());
87        }
88    } else {
89        // 查找目标链路
90        if dst == Some(&ID_ALL) {
91            return unmp_link::broadcast(&buf).await;
92        }
93        let mut router = ROUTER.lock().await;
94        let result_child = router.send_to(&buf, &id, exclude).await;
95        let result_parent = router.send_to_parent(&buf, exclude).await;
96        match (result_child, result_parent) {
97            (Err(_), Err(_)) => return Err(()),
98            _ => return Ok(()),
99        }
100    }
101}
102/// 发送数据包到指定设备
103pub(crate) async fn send(protocol: u8, data: &[u8], conn: &Connection) -> Result<(), ()> {
104    let mut head = PacketHeader::new(get_id(), conn.id().clone());
105    head.set_ttl(if conn.id() == &ID_ALL { 0 } else { 8 });
106    head.set_protocol(protocol);
107    return send_common(&head, &data, Some(conn.id()), conn.link(), None).await;
108}
109/// 收到了数据包
110pub fn when_recv(link: &Link, buf: &[u8]) {
111    info!("net recv: {:02X?}, from {}.", buf, link);
112    let pkt = match Packet::parse(buf) {
113        Ok(result) => result,
114        Err(err) => {
115            warn!("net recv err: {}", err);
116            return;
117        }
118    };
119    if !HISTORY.add(&pkt) {
120        info!("net packet repeat.");
121        return;
122    }
123    let mut head = pkt.head().clone();
124    let data = pkt.data();
125
126    if head.src() != &ID_PARENT && head.src() != &ID_ALL {
127        connect(&head.src(), link.clone());
128    }
129
130    let cfg = CFG.get().unwrap();
131    if head.dst() == &cfg.id || head.dst() == &ID_PARENT || head.dst() == &ID_ALL {
132        let mut conn = Connection::new(head.src().clone());
133        conn.set_link(link.clone());
134        let protocol_id = head.protocol();
135        let data = VecData::from(data);
136        protocol::distribute(protocol_id, conn, &data);
137    } else if cfg.relay {
138        if head.ttl() > 0 {
139            head.set_ttl(head.ttl() - 1);
140            let data = VecData::from(data);
141            let link = link.clone();
142            task_stream::spawn(async move {
143                if let Err(()) = send_common(&head, &data, None, None, Some(&link)).await {
144                    info!("can't forward");
145                }
146            });
147        } else {
148            info!("unmp TTL is 0.");
149        }
150    }
151}
152/// 更新路由
153pub fn connect(id: &Id, link: Link) {
154    let id = id.clone();
155    task_stream::spawn(async move {
156        ROUTER.lock().await.add(&id, link);
157    });
158}
159/// 链路断开
160fn when_link_disconnect(link: &Link) {
161    let link = link.clone();
162    task_stream::spawn(async move {
163        ROUTER.lock().await.when_link_disconnect(&link);
164    });
165}
166/// 注册设备断开回调
167pub fn on_disconnect(cb: DisconnectCb) {
168    task_stream::spawn(async move {
169        ROUTER.lock().await.on_disconnect(cb);
170    });
171}