doip_tokio/
server.rs

1use crate::{DoIpCodec, DoIpTokioError, TCP_DATA_TLS_PORT};
2use async_trait::async_trait;
3use bytes::Bytes;
4use doip::{
5    ActivationType, AliveCheckResponse, DiagnosticEntityStatusResponse,
6    DiagnosticMessagePositiveAck, DiagnosticPowerMode, DoIpError, DoIpHeader,
7    FurtherActionRequired, NodeType, PayloadType, RoutingActivationRequest,
8    RoutingActivationResponse, VehicleIdentificationResponse, VinGidSyncStatus,
9};
10use futures::{SinkExt, StreamExt};
11use socket2::{Domain, Protocol, Socket, Type};
12use std::io::Cursor;
13use std::net::{IpAddr, SocketAddr};
14use std::sync::atomic::{AtomicU8, Ordering};
15use std::{io, sync::Arc};
16use thiserror::Error;
17use tls_api::TlsAcceptor;
18use tokio::net::{TcpListener, TcpStream, UdpSocket};
19use tokio_util::codec::Framed;
20use tracing::{debug, error};
21
/// Details of the requesting client.
///
/// A reference to this context is handed to every [`DoIpServerHandler`] callback.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ClientContext {
    /// Client's IP address.
    pub addr: IpAddr,
    /// Client's logical address aka source address.
    /// Valid range: 0x0E00 - 0x0FFF
    pub logical_addr: u16,
}
30
31#[async_trait]
32pub trait DoIpServerHandler<E> {
33    /// Vehicle Identification Number
34    const VIN: [u8; 17];
35    // valid range 0x0001 - 0x0DFF
36    const LOGICAL_ADDRESS: u16;
37    /// Unique entitiy identification (EID), e.g. MAC address of network interface.
38    const EID: [u8; 6];
39    //// Unique group identification of entities within a vehicle.
40    /// None when value not set (as indicated by `0x00` or `0xFF`).
41    const GID: Option<[u8; 6]>;
42
43    /// Identify vehicle by Entity ID (EID).
44    /// * `eid` - Unique DoIP enitity ID (e.g. MAC address)
45    async fn vehicle_identification_with_eid(
46        &self,
47        ctx: &ClientContext,
48        eid: &[u8; 6],
49    ) -> Result<VehicleIdentificationResponse, E> {
50        if Self::EID == *eid {
51            Ok(VehicleIdentificationResponse {
52                eid: Self::EID,
53                logical_address: Self::LOGICAL_ADDRESS.to_be_bytes(),
54                vin: Self::VIN,
55                gid: Self::GID,
56                further_action: FurtherActionRequired::NoFurtherActionRequried,
57                vin_gid_sync_status: VinGidSyncStatus::Synchronized,
58            })
59        } else {
60            todo!();
61        }
62    }
63
64    /// Identify vehicle by Vehicle Identification Number (VIN).
65    /// * `vin` - VIN as defined in ISO 3779.
66    async fn vehicle_identification_with_vin(
67        &self,
68        ctx: &ClientContext,
69        vin: &[u8; 17],
70    ) -> Result<VehicleIdentificationResponse, E> {
71        if Self::VIN == *vin {
72            Ok(VehicleIdentificationResponse {
73                eid: Self::EID,
74                logical_address: Self::LOGICAL_ADDRESS.to_be_bytes(),
75                vin: Self::VIN,
76                gid: Self::GID,
77                further_action: FurtherActionRequired::NoFurtherActionRequried,
78                vin_gid_sync_status: VinGidSyncStatus::Synchronized,
79            })
80        } else {
81            todo!();
82        }
83    }
84
85    async fn routing_activation(
86        &self,
87        ctx: &ClientContext,
88        source_address: u16,
89        activation_type: ActivationType,
90    ) -> Result<RoutingActivationResponse, E>;
91
92    async fn alive_check(&self, ctx: &ClientContext) -> Result<AliveCheckResponse, E> {
93        Ok(AliveCheckResponse {
94            source_address: ctx.logical_addr,
95        })
96    }
97
98    async fn diagnostic_power_mode_information(
99        &self,
100        ctx: &ClientContext,
101    ) -> Result<DiagnosticPowerMode, E> {
102        Ok(DiagnosticPowerMode::NotSupported)
103    }
104
105    // TODO currenly open sockets have to be tracked by implementor
106    // if this can be overridden
107    // async fn diagnostic_status_entity(
108    //     &self,
109    //     ctx: &ClientContext,
110    // ) -> Result<DiagnosticEntityStatusResponse, E>;
111
112    async fn diagnostic_message(
113        &self,
114        ctx: &ClientContext,
115        source_address: u16,
116        target_address: u16,
117        user_data: Vec<u8>,
118    ) -> Result<DiagnosticMessagePositiveAck, E>;
119}
120
121#[derive(Error, Debug)]
122pub enum ServerError {
123    #[error(transparent)]
124    Io(#[from] io::Error),
125    #[error("The server logical address: {0:X} is not within the valid range 0x0001 - 0x0DFF")]
126    InvalidServerLogicalAddr(u16),
127    // #[error(transparent)]
128    // Ssl(#[from] openssl::error::ErrorStack),
129    #[error(transparent)]
130    Anyhow(#[from] anyhow::Error),
131    #[error(transparent)]
132    DoIp(#[from] DoIpError),
133    #[error(transparent)]
134    DoIpTokio(#[from] DoIpTokioError),
135    #[error("Unsupported payload type: {0:?}")]
136    Unsupported(PayloadType),
137}
138
/// DoIP server that accepts TLS-wrapped TCP clients and dispatches their requests to a handler.
///
/// The trait bounds on `T` and `TA` live on the `impl` block that needs them rather than
/// on the type itself.
pub struct DoIpServer<T, TA> {
    /// User supplied request handler.
    handler: Arc<T>,
    /// Local IP address the UDP socket is bound to.
    addr: IpAddr,
    /// Acceptor used to upgrade accepted TCP connections to TLS.
    tls_acceptor: Arc<TA>,
    /// Number of client connections currently being served.
    currently_open_sockets: AtomicU8,
}
145
146impl<T: DoIpServerHandler<ServerError> + std::marker::Sync, TA: TlsAcceptor> DoIpServer<T, TA> {
147    pub fn new(handler: T, addr: IpAddr, tls_acceptor: TA) -> Result<Self, ServerError> {
148        if T::LOGICAL_ADDRESS < 0x0001 || T::LOGICAL_ADDRESS > 0x0DFF {
149            return Err(ServerError::InvalidServerLogicalAddr(T::LOGICAL_ADDRESS));
150        }
151
152        Ok(Self {
153            handler: Arc::new(handler),
154            addr,
155            tls_acceptor: Arc::new(tls_acceptor),
156            currently_open_sockets: AtomicU8::new(0),
157        })
158    }
159
160    pub async fn serve(self) -> Result<(), ServerError> {
161        // Tokio's UdpSocket does not directly offer "set_reuse_address", go with socket2
162        let udp_socket = Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP))?;
163        udp_socket.set_reuse_address(true)?;
164        udp_socket.set_broadcast(true)?;
165
166        // The port zero indicates that a random, free port is chosen.
167        let client_addr_udp = SocketAddr::new(self.addr, 0);
168        udp_socket.bind(&client_addr_udp.into())?;
169        let udp_socket = UdpSocket::from_std(udp_socket.into())?;
170
171        let listener = TcpListener::bind(("0.0.0.0", TCP_DATA_TLS_PORT)).await?;
172
173        loop {
174            match listener.accept().await {
175                Ok((tcp_stream, client_socket_addr)) => {
176                    if let Err(client_error) =
177                        self.handle_client(client_socket_addr, tcp_stream).await
178                    {
179                        error!("Error occured: {client_error}");
180                    }
181                }
182                Err(accept_error) => {
183                    error!("Failed to accept new TCP client: {accept_error}");
184                }
185            }
186        }
187    }
188
189    #[tracing::instrument(skip(self, tcp_stream))]
190    async fn handle_client(
191        &self,
192        client_socket_addr: SocketAddr,
193        tcp_stream: TcpStream,
194    ) -> Result<(), ServerError> {
195        let currently_open_sockets = self.currently_open_sockets.fetch_add(1, Ordering::Relaxed);
196        debug!("New client connected addr: {client_socket_addr}, previous open sockets: {currently_open_sockets}");
197
198        let tls_stream = self.tls_acceptor.accept(tcp_stream).await?;
199        debug!("TLS accepted addr: {client_socket_addr}");
200
201        let mut client_tls_stream = Framed::new(tls_stream, DoIpCodec {});
202
203        loop {
204            match client_tls_stream.next().await {
205                Some(Ok((header, payload))) => {
206                    let (response_header, response_payload) = self
207                        .handle_client_message(client_socket_addr, header, payload)
208                        .await?;
209
210                    client_tls_stream
211                        .send((&response_header, &response_payload))
212                        .await?;
213                }
214                Some(Err(codec_error)) => {
215                    error!("Client, decoding error source: {client_socket_addr}, {codec_error}")
216                }
217                None => {
218                    debug!("Client stream closed, client addr: {client_socket_addr}");
219                    self.currently_open_sockets.fetch_sub(1, Ordering::Relaxed);
220                    return Ok(());
221                }
222            }
223        }
224    }
225
226    async fn handle_client_message(
227        &self,
228        client_socket_addr: SocketAddr,
229        header: DoIpHeader,
230        client_payload: Bytes,
231    ) -> Result<(DoIpHeader, Vec<u8>), ServerError> {
232        debug!("Received Client DoIp message: {header:?}");
233        let ctx = ClientContext {
234            addr: client_socket_addr.ip(),
235            logical_addr: 0x0000, // TODO fix this constant
236        };
237        let mut response_payload = Vec::new();
238
239        let response: Result<_, ServerError> = match header.payload_type {
240            PayloadType::AliveCheckRequest => {
241                let response = self.handler.alive_check(&ctx).await?;
242                response.write(&mut response_payload)?;
243
244                Ok((PayloadType::AliveCheckResponse, response_payload))
245            }
246            PayloadType::RoutingActivationRequest => {
247                let request = RoutingActivationRequest::read(&mut Cursor::new(client_payload))?;
248                let source_address = u16::from_be_bytes(request.source_address);
249                let response = self
250                    .handler
251                    .routing_activation(&ctx, source_address, request.activation_type)
252                    .await?;
253
254                response.write(&mut response_payload)?;
255
256                Ok((PayloadType::RoutingActivationResponse, response_payload))
257            }
258            PayloadType::DoIpEntityStatusRequest => {
259                let response = DiagnosticEntityStatusResponse {
260                    node_type: NodeType::DoIpNode,
261                    max_open_sockets: u8::MAX,
262                    currently_open_sockets: self.currently_open_sockets.load(Ordering::Relaxed),
263                    max_data_size: u32::MAX,
264                };
265                response.write(&mut response_payload)?;
266                Ok((PayloadType::DoIpEntityStatusResponse, response_payload))
267            }
268            // TODO add remaining
269            _ => Err(ServerError::Unsupported(header.payload_type)),
270        };
271
272        let (payload_type, payload) = response?;
273
274        let header = DoIpHeader::new(payload_type, payload.len() as u32);
275        Ok((header, payload))
276    }
277}