mctp_estack/
control.rs

1// SPDX-License-Identifier: MIT OR Apache-2.0
2/*
3 * MCTP common types and traits.
4 *
5 * Copyright (c) 2024-2025 Code Construct
6 */
7
8//! MCTP Control Protocol implementation
9
10use crate::fmt::*;
11use crate::Router;
12use libmctp::control_packet::CompletionCode;
13use mctp::{AsyncRespChannel, Eid, Error, Listener, MsgIC, MsgType};
14use uuid::Uuid;
15
16pub use libmctp::control_packet::CommandCode;
17
18type Header = libmctp::control_packet::MCTPControlMessageHeader<[u8; 2]>;
19
20/// A `Result` with a MCTP Control Completion Code as error
21pub type ControlResult<T> =
22    core::result::Result<T, libmctp::control_packet::CompletionCode>;
23
24pub struct MctpControlMsg<'a> {
25    pub header: Header,
26    pub body: &'a [u8],
27}
28
/// Size of the response body buffer held by `MctpControl`.
///
/// Response headers are stored separately, so this only has to hold a body.
/// The largest fixed-size body is Get Endpoint UUID: a completion code plus a
/// 16-byte UUID (17 bytes).
const MAX_MSG_SIZE: usize = 20; /* largest is Get Endpoint UUID */
/// Maximum number of message types reported by Get Message Type Support.
const MAX_MSG_TYPES: usize = 8;

// The response buffer must fit both the Get Endpoint UUID body and a full
// Get Message Type Support body (completion code + count + one byte per type);
// otherwise those commands would fail at runtime with a short-buffer error.
const _: () = assert!(MAX_MSG_SIZE >= 1 + 16 && MAX_MSG_SIZE >= 2 + MAX_MSG_TYPES);
31
32impl<'a> MctpControlMsg<'a> {
33    pub fn from_buf(buf: &'a [u8]) -> ControlResult<Self> {
34        if buf.len() < 2 {
35            return Err(CompletionCode::ErrorInvalidLength);
36        }
37        let b: [u8; 2] = buf[..2].try_into().unwrap();
38        let header = Header::new_from_buf(b);
39
40        if header.d() != 0 {
41            // Datagram bit is unhandled
42            return Err(CompletionCode::ErrorInvalidData);
43        }
44
45        let body = &buf[2..];
46        Ok(Self { header, body })
47    }
48
49    pub fn new_resp<'f>(
50        &self,
51        body: &'f [u8],
52    ) -> ControlResult<MctpControlMsg<'f>> {
53        if self.header.rq() == 0 {
54            return Err(CompletionCode::ErrorInvalidData);
55        }
56
57        let mut header = Header::new_from_buf(self.header.0);
58        header.set_rq(0);
59
60        Ok(MctpControlMsg { header, body })
61    }
62
63    pub fn slices(&self) -> [&[u8]; 2] {
64        [&self.header.0, self.body]
65    }
66
67    /// Extract the MCTP control message command code.
68    ///
69    /// Unrecognised values will return `Err`.
70    pub fn command_code(&self) -> core::result::Result<CommandCode, u8> {
71        let cc = self.header.command_code();
72        match CommandCode::from(cc) {
73            CommandCode::Unknown => Err(cc),
74            cmd => Ok(cmd),
75        }
76    }
77}
78
79pub fn respond_get_eid<'a>(
80    req: &MctpControlMsg,
81    eid: Eid,
82    medium_specific: u8,
83    rsp_buf: &'a mut [u8],
84) -> ControlResult<MctpControlMsg<'a>> {
85    if req.command_code() != Ok(CommandCode::GetEndpointID) {
86        return Err(CompletionCode::Error);
87    }
88    if !req.body.is_empty() {
89        return Err(CompletionCode::ErrorInvalidLength);
90    }
91    // simple endpoint, static EID supported
92    let endpoint_type = 0b0000_0001;
93    let body = [
94        CompletionCode::Success as u8,
95        eid.0,
96        endpoint_type,
97        medium_specific,
98    ];
99
100    let rsp_buf = &mut rsp_buf[0..body.len()];
101    rsp_buf.clone_from_slice(&body);
102    req.new_resp(rsp_buf)
103}
104
105#[derive(Debug)]
106pub struct SetEndpointId {
107    pub eid: Eid,
108    pub force: bool,
109    pub reset: bool,
110}
111
112pub fn parse_set_eid(req: &MctpControlMsg) -> ControlResult<SetEndpointId> {
113    if req.command_code() != Ok(CommandCode::SetEndpointID) {
114        return Err(CompletionCode::Error);
115    }
116    if req.body.len() != 2 {
117        return Err(CompletionCode::ErrorInvalidLength);
118    }
119
120    let eid = Eid::new_normal(req.body[1]).map_err(|_| {
121        warn!("Invalid Set EID {}", req.body[1]);
122        CompletionCode::ErrorInvalidData
123    })?;
124
125    let mut ret = SetEndpointId {
126        eid,
127        force: false,
128        reset: false,
129    };
130
131    match req.body[0] & 0x03 {
132        // Set
133        0b00 => (),
134        // Force
135        0b01 => ret.force = true,
136        // Reset
137        0b10 => ret.reset = true,
138        // Set Discovered
139        0b11 => return Err(CompletionCode::ErrorInvalidData),
140        _ => unreachable!(),
141    }
142
143    Ok(ret)
144}
145
146pub fn respond_set_eid<'a>(
147    req: &MctpControlMsg,
148    accepted: bool,
149    current_eid: Eid,
150    rsp_buf: &'a mut [u8],
151) -> ControlResult<MctpControlMsg<'a>> {
152    if req.command_code() != Ok(CommandCode::SetEndpointID) {
153        return Err(CompletionCode::Error);
154    }
155    let status = if accepted { 0b00000000 } else { 0b00010000 };
156    let body = [CompletionCode::Success as u8, status, current_eid.0, 0x00];
157    let rsp_buf = &mut rsp_buf[0..body.len()];
158    rsp_buf.clone_from_slice(&body);
159    req.new_resp(rsp_buf)
160}
161
162pub fn respond_get_uuid<'a>(
163    req: &MctpControlMsg,
164    uuid: Uuid,
165    rsp_buf: &'a mut [u8],
166) -> ControlResult<MctpControlMsg<'a>> {
167    if req.command_code() != Ok(CommandCode::GetEndpointUUID) {
168        return Err(CompletionCode::Error);
169    }
170
171    let mut body = [0u8; 1 + 16];
172    body[0] = CompletionCode::Success as u8;
173    body[1..].clone_from_slice(uuid.as_bytes());
174
175    let rsp_buf = &mut rsp_buf[0..body.len()];
176    rsp_buf.clone_from_slice(&body);
177    req.new_resp(rsp_buf)
178}
179
180pub fn respond_get_msg_types<'a>(
181    req: &MctpControlMsg,
182    msgtypes: &[MsgType],
183    rsp_buf: &'a mut [u8],
184) -> ControlResult<MctpControlMsg<'a>> {
185    if req.command_code() != Ok(CommandCode::GetMessageTypeSupport) {
186        return Err(CompletionCode::Error);
187    }
188    if !req.body.is_empty() {
189        return Err(CompletionCode::ErrorInvalidLength);
190    }
191    let n = msgtypes.len();
192    let body = rsp_buf.get_mut(..n + 2).ok_or(CompletionCode::Error)?;
193    body[0] = CompletionCode::Success as u8;
194    body[1] = n as u8;
195    for (i, t) in msgtypes.iter().enumerate() {
196        body[i + 2] = t.0;
197    }
198    req.new_resp(body)
199}
200
201pub fn respond_unimplemented<'a>(
202    req: &MctpControlMsg,
203    rsp_buf: &'a mut [u8],
204) -> mctp::Result<MctpControlMsg<'a>> {
205    respond_error(req, CompletionCode::ErrorUnsupportedCmd, rsp_buf)
206}
207
208/// Respond with an error completion code.
209///
210/// This returns a `mctp::Result` since failures can't be sent as a response.
211pub fn respond_error<'a>(
212    req: &MctpControlMsg,
213    err: CompletionCode,
214    rsp_buf: &'a mut [u8],
215) -> mctp::Result<MctpControlMsg<'a>> {
216    if err == CompletionCode::Success {
217        return Err(Error::BadArgument);
218    }
219    let body = [err as u8];
220    let rsp_buf = &mut rsp_buf[0..body.len()];
221    rsp_buf.clone_from_slice(&body);
222    req.new_resp(rsp_buf)
223        .map_err(|_| mctp::Error::InternalError)
224}
225
226pub fn mctp_control_rx_req<'f, 'l, L>(
227    listener: &'l mut L,
228    buf: &'f mut [u8],
229) -> mctp::Result<(MctpControlMsg<'f>, L::RespChannel<'l>)>
230where
231    L: Listener,
232{
233    let (typ, ic, buf, ch) = listener.recv(buf)?;
234    if ic.0 {
235        return Err(Error::InvalidInput);
236    }
237    if typ != mctp::MCTP_TYPE_CONTROL {
238        // Listener was bound to the wrong type?
239        return Err(Error::BadArgument);
240    }
241
242    let msg = MctpControlMsg::from_buf(buf).map_err(|_| Error::InvalidInput)?;
243    Ok((msg, ch))
244}
245
246/// A Control Message handler.
247pub struct MctpControl<'a> {
248    rsp_buf: [u8; MAX_MSG_SIZE],
249    types: heapless::Vec<MsgType, MAX_MSG_TYPES>,
250    uuid: Option<Uuid>,
251    router: &'a Router<'a>,
252}
253
254impl<'a> MctpControl<'a> {
255    pub fn new(router: &'a Router<'a>) -> Self {
256        Self {
257            rsp_buf: [0u8; MAX_MSG_SIZE],
258            types: heapless::Vec::new(),
259            uuid: None,
260            router,
261        }
262    }
263
264    pub async fn handle_async(
265        &mut self,
266        msg: &[u8],
267        mut resp_chan: impl AsyncRespChannel,
268    ) -> mctp::Result<()> {
269        let req = MctpControlMsg::from_buf(msg)
270            .map_err(|_| mctp::Error::InvalidInput)?;
271
272        let resp = match self.handle_req(&req).await {
273            Err(e) => {
274                debug!("Control error response {:?}", e);
275                respond_error(&req, e, &mut self.rsp_buf)
276            }
277            Ok(r) => Ok(r),
278        }?;
279
280        resp_chan.send_vectored(MsgIC(false), &resp.slices()).await
281    }
282
283    pub fn set_message_types(&mut self, types: &[MsgType]) -> mctp::Result<()> {
284        if types.len() > self.types.capacity() {
285            return Err(mctp::Error::NoSpace);
286        }
287        self.types.clear();
288        // We have already checked the length, so no Err here
289        let _ = self.types.extend_from_slice(types);
290        Ok(())
291    }
292
293    pub fn set_uuid(&mut self, uuid: &Uuid) {
294        let _ = self.uuid.insert(*uuid);
295    }
296
297    async fn handle_req(
298        &mut self,
299        req: &'_ MctpControlMsg<'_>,
300    ) -> ControlResult<MctpControlMsg<'_>> {
301        let cc = req.command_code().map_err(|cc| {
302            debug!("Unsupported control command {}", cc);
303            CompletionCode::ErrorUnsupportedCmd
304        })?;
305
306        #[cfg(feature = "log")]
307        debug!("Control request {:?}", cc);
308        match cc {
309            CommandCode::GetEndpointID => {
310                let eid = self.router.get_eid().await;
311                respond_get_eid(req, eid, 0, &mut self.rsp_buf)
312            }
313            CommandCode::SetEndpointID => {
314                let set = parse_set_eid(req)?;
315                let res = self.router.set_eid(set.eid).await;
316                let eid = self.router.get_eid().await;
317
318                respond_set_eid(req, res.is_ok(), eid, &mut self.rsp_buf)
319            }
320            CommandCode::GetEndpointUUID => {
321                if let Some(uuid) = self.uuid {
322                    respond_get_uuid(req, uuid, &mut self.rsp_buf)
323                } else {
324                    Err(CompletionCode::ErrorUnsupportedCmd)
325                }
326            }
327            CommandCode::GetMessageTypeSupport => respond_get_msg_types(
328                req,
329                self.types.as_slice(),
330                &mut self.rsp_buf,
331            ),
332            _ => Err(CompletionCode::ErrorUnsupportedCmd),
333        }
334    }
335}