1use crate::{DoIpCodec, DoIpTokioError, TCP_DATA_TLS_PORT};
2use async_trait::async_trait;
3use bytes::Bytes;
4use doip::{
5 ActivationType, AliveCheckResponse, DiagnosticEntityStatusResponse,
6 DiagnosticMessagePositiveAck, DiagnosticPowerMode, DoIpError, DoIpHeader,
7 FurtherActionRequired, NodeType, PayloadType, RoutingActivationRequest,
8 RoutingActivationResponse, VehicleIdentificationResponse, VinGidSyncStatus,
9};
10use futures::{SinkExt, StreamExt};
11use socket2::{Domain, Protocol, Socket, Type};
12use std::io::Cursor;
13use std::net::{IpAddr, SocketAddr};
14use std::sync::atomic::{AtomicU8, Ordering};
15use std::{io, sync::Arc};
16use thiserror::Error;
17use tls_api::TlsAcceptor;
18use tokio::net::{TcpListener, TcpStream, UdpSocket};
19use tokio_util::codec::Framed;
20use tracing::{debug, error};
21
/// Per-request information about the connected client, passed to every
/// [`DoIpServerHandler`] callback.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ClientContext {
    /// IP address of the remote peer the request arrived from.
    pub addr: IpAddr,
    /// DoIP logical address associated with the client. The default
    /// `alive_check` handler echoes it back as the response source address.
    pub logical_addr: u16,
}
30
31#[async_trait]
32pub trait DoIpServerHandler<E> {
33 const VIN: [u8; 17];
35 const LOGICAL_ADDRESS: u16;
37 const EID: [u8; 6];
39 const GID: Option<[u8; 6]>;
42
43 async fn vehicle_identification_with_eid(
46 &self,
47 ctx: &ClientContext,
48 eid: &[u8; 6],
49 ) -> Result<VehicleIdentificationResponse, E> {
50 if Self::EID == *eid {
51 Ok(VehicleIdentificationResponse {
52 eid: Self::EID,
53 logical_address: Self::LOGICAL_ADDRESS.to_be_bytes(),
54 vin: Self::VIN,
55 gid: Self::GID,
56 further_action: FurtherActionRequired::NoFurtherActionRequried,
57 vin_gid_sync_status: VinGidSyncStatus::Synchronized,
58 })
59 } else {
60 todo!();
61 }
62 }
63
64 async fn vehicle_identification_with_vin(
67 &self,
68 ctx: &ClientContext,
69 vin: &[u8; 17],
70 ) -> Result<VehicleIdentificationResponse, E> {
71 if Self::VIN == *vin {
72 Ok(VehicleIdentificationResponse {
73 eid: Self::EID,
74 logical_address: Self::LOGICAL_ADDRESS.to_be_bytes(),
75 vin: Self::VIN,
76 gid: Self::GID,
77 further_action: FurtherActionRequired::NoFurtherActionRequried,
78 vin_gid_sync_status: VinGidSyncStatus::Synchronized,
79 })
80 } else {
81 todo!();
82 }
83 }
84
85 async fn routing_activation(
86 &self,
87 ctx: &ClientContext,
88 source_address: u16,
89 activation_type: ActivationType,
90 ) -> Result<RoutingActivationResponse, E>;
91
92 async fn alive_check(&self, ctx: &ClientContext) -> Result<AliveCheckResponse, E> {
93 Ok(AliveCheckResponse {
94 source_address: ctx.logical_addr,
95 })
96 }
97
98 async fn diagnostic_power_mode_information(
99 &self,
100 ctx: &ClientContext,
101 ) -> Result<DiagnosticPowerMode, E> {
102 Ok(DiagnosticPowerMode::NotSupported)
103 }
104
105 async fn diagnostic_message(
113 &self,
114 ctx: &ClientContext,
115 source_address: u16,
116 target_address: u16,
117 user_data: Vec<u8>,
118 ) -> Result<DiagnosticMessagePositiveAck, E>;
119}
120
121#[derive(Error, Debug)]
122pub enum ServerError {
123 #[error(transparent)]
124 Io(#[from] io::Error),
125 #[error("The server logical address: {0:X} is not within the valid range 0x0001 - 0x0DFF")]
126 InvalidServerLogicalAddr(u16),
127 #[error(transparent)]
130 Anyhow(#[from] anyhow::Error),
131 #[error(transparent)]
132 DoIp(#[from] DoIpError),
133 #[error(transparent)]
134 DoIpTokio(#[from] DoIpTokioError),
135 #[error("Unsupported payload type: {0:?}")]
136 Unsupported(PayloadType),
137}
138
139pub struct DoIpServer<T: DoIpServerHandler<ServerError>, TA: TlsAcceptor> {
140 handler: Arc<T>,
141 addr: IpAddr,
142 tls_acceptor: Arc<TA>,
143 currently_open_sockets: AtomicU8,
144}
145
146impl<T: DoIpServerHandler<ServerError> + std::marker::Sync, TA: TlsAcceptor> DoIpServer<T, TA> {
147 pub fn new(handler: T, addr: IpAddr, tls_acceptor: TA) -> Result<Self, ServerError> {
148 if T::LOGICAL_ADDRESS < 0x0001 || T::LOGICAL_ADDRESS > 0x0DFF {
149 return Err(ServerError::InvalidServerLogicalAddr(T::LOGICAL_ADDRESS));
150 }
151
152 Ok(Self {
153 handler: Arc::new(handler),
154 addr,
155 tls_acceptor: Arc::new(tls_acceptor),
156 currently_open_sockets: AtomicU8::new(0),
157 })
158 }
159
160 pub async fn serve(self) -> Result<(), ServerError> {
161 let udp_socket = Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP))?;
163 udp_socket.set_reuse_address(true)?;
164 udp_socket.set_broadcast(true)?;
165
166 let client_addr_udp = SocketAddr::new(self.addr, 0);
168 udp_socket.bind(&client_addr_udp.into())?;
169 let udp_socket = UdpSocket::from_std(udp_socket.into())?;
170
171 let listener = TcpListener::bind(("0.0.0.0", TCP_DATA_TLS_PORT)).await?;
172
173 loop {
174 match listener.accept().await {
175 Ok((tcp_stream, client_socket_addr)) => {
176 if let Err(client_error) =
177 self.handle_client(client_socket_addr, tcp_stream).await
178 {
179 error!("Error occured: {client_error}");
180 }
181 }
182 Err(accept_error) => {
183 error!("Failed to accept new TCP client: {accept_error}");
184 }
185 }
186 }
187 }
188
189 #[tracing::instrument(skip(self, tcp_stream))]
190 async fn handle_client(
191 &self,
192 client_socket_addr: SocketAddr,
193 tcp_stream: TcpStream,
194 ) -> Result<(), ServerError> {
195 let currently_open_sockets = self.currently_open_sockets.fetch_add(1, Ordering::Relaxed);
196 debug!("New client connected addr: {client_socket_addr}, previous open sockets: {currently_open_sockets}");
197
198 let tls_stream = self.tls_acceptor.accept(tcp_stream).await?;
199 debug!("TLS accepted addr: {client_socket_addr}");
200
201 let mut client_tls_stream = Framed::new(tls_stream, DoIpCodec {});
202
203 loop {
204 match client_tls_stream.next().await {
205 Some(Ok((header, payload))) => {
206 let (response_header, response_payload) = self
207 .handle_client_message(client_socket_addr, header, payload)
208 .await?;
209
210 client_tls_stream
211 .send((&response_header, &response_payload))
212 .await?;
213 }
214 Some(Err(codec_error)) => {
215 error!("Client, decoding error source: {client_socket_addr}, {codec_error}")
216 }
217 None => {
218 debug!("Client stream closed, client addr: {client_socket_addr}");
219 self.currently_open_sockets.fetch_sub(1, Ordering::Relaxed);
220 return Ok(());
221 }
222 }
223 }
224 }
225
226 async fn handle_client_message(
227 &self,
228 client_socket_addr: SocketAddr,
229 header: DoIpHeader,
230 client_payload: Bytes,
231 ) -> Result<(DoIpHeader, Vec<u8>), ServerError> {
232 debug!("Received Client DoIp message: {header:?}");
233 let ctx = ClientContext {
234 addr: client_socket_addr.ip(),
235 logical_addr: 0x0000, };
237 let mut response_payload = Vec::new();
238
239 let response: Result<_, ServerError> = match header.payload_type {
240 PayloadType::AliveCheckRequest => {
241 let response = self.handler.alive_check(&ctx).await?;
242 response.write(&mut response_payload)?;
243
244 Ok((PayloadType::AliveCheckResponse, response_payload))
245 }
246 PayloadType::RoutingActivationRequest => {
247 let request = RoutingActivationRequest::read(&mut Cursor::new(client_payload))?;
248 let source_address = u16::from_be_bytes(request.source_address);
249 let response = self
250 .handler
251 .routing_activation(&ctx, source_address, request.activation_type)
252 .await?;
253
254 response.write(&mut response_payload)?;
255
256 Ok((PayloadType::RoutingActivationResponse, response_payload))
257 }
258 PayloadType::DoIpEntityStatusRequest => {
259 let response = DiagnosticEntityStatusResponse {
260 node_type: NodeType::DoIpNode,
261 max_open_sockets: u8::MAX,
262 currently_open_sockets: self.currently_open_sockets.load(Ordering::Relaxed),
263 max_data_size: u32::MAX,
264 };
265 response.write(&mut response_payload)?;
266 Ok((PayloadType::DoIpEntityStatusResponse, response_payload))
267 }
268 _ => Err(ServerError::Unsupported(header.payload_type)),
270 };
271
272 let (payload_type, payload) = response?;
273
274 let header = DoIpHeader::new(payload_type, payload.len() as u32);
275 Ok((header, payload))
276 }
277}