playit_agent_core/agent_control/connected_control.rs

1use std::{net::SocketAddr, time::Duration};
2
3use message_encoding::MessageEncoding;
4use playit_agent_proto::{control_feed::ControlFeed, control_messages::{AgentRegistered, ControlRequest, ControlResponse, Ping, Pong}, raw_slice::RawSlice, rpc::ControlRpcMessage};
5
6use crate::utils::now_milli;
7
8use super::{errors::{ControlError, SetupError}, established_control::EstablishedControl, AuthResource, PacketIO};
9
10#[derive(Debug)]
11pub struct ConnectedControl<IO: PacketIO> {
12    pub(super) control_addr: SocketAddr,
13    pub(super) packet_io: IO,
14    pub(super) pong_latest: Pong,
15    pub(super) buffer: Vec<u8>,
16}
17
18impl<IO: PacketIO> ConnectedControl<IO> {
19    pub fn new(control_addr: SocketAddr, udp: IO, pong: Pong) -> Self {
20        ConnectedControl { control_addr, packet_io: udp, pong_latest: pong, buffer: Vec::with_capacity(1024) }
21    }
22
23    pub fn control_addr(&self) -> SocketAddr {
24        self.control_addr
25    }
26
27    pub fn pong(&self) -> Pong {
28        self.pong_latest.clone()
29    }
30
31    pub async fn auth_into_established<A: AuthResource>(mut self, auth: A) -> Result<EstablishedControl<A, IO>, SetupError> {
32        let registered = self.authenticate(&auth).await?;
33        Ok(self.into_established(auth, registered))
34    }
35
36    pub fn into_established<A: AuthResource>(self, auth: A, registered: AgentRegistered) -> EstablishedControl<A, IO> {
37        let pong = self.pong_latest.clone();
38
39        EstablishedControl {
40            auth,
41            conn: self,
42            pong_at_auth: pong,
43            registered,
44            current_ping: None,
45            clock_offset: 0,
46            force_expired: false,
47        }
48    }
49
50    pub fn reset_established<A: AuthResource>(self, established: &mut EstablishedControl<A, IO>, registered: AgentRegistered) {
51        established.registered = registered;
52        established.pong_at_auth = self.pong_latest.clone();
53        established.conn = self;
54        established.current_ping = None;
55        established.force_expired = false;
56    }
57
58    pub async fn authenticate<A: AuthResource>(&mut self, auth: &A) -> Result<AgentRegistered, SetupError> {
59        let auth_pong = self.pong_latest.clone();
60        let res = auth.authenticate(&auth_pong).await?;
61
62        let bytes = match hex::decode(&res.key) {
63            Ok(data) => data,
64            Err(_) => return Err(SetupError::FailedToDecodeSignedAgentRegisterHex),
65        };
66
67        let request_id = now_milli();
68
69        for _ in 0..5 {
70            self.send(&ControlRpcMessage {
71                request_id,
72                content: RawSlice(&bytes),
73            }).await?;
74
75            for _ in 0..5 {
76                let mesage = match tokio::time::timeout(Duration::from_millis(500), self.recv()).await {
77                    Ok(Ok(msg)) => msg,
78                    Ok(Err(error)) => {
79                        tracing::error!(?error, "got error reading from socket");
80                        break;
81                    }
82                    Err(_) => {
83                        tracing::error!("timeout waiting for register response");
84                        continue;
85                    }
86                };
87
88                let response = match mesage {
89                    ControlFeed::Response(response) if response.request_id == request_id => response,
90                    other => {
91                        tracing::error!(?other, "got unexpected response from register request");
92                        continue;
93                    }
94                };
95
96                return match response.content {
97                    ControlResponse::AgentRegistered(registered) => Ok(registered),
98                    ControlResponse::InvalidSignature => Err(SetupError::RegisterInvalidSignature),
99                    ControlResponse::Unauthorized => {
100                        /* most likely due to a changed client addr, send pong to refresh value */
101                        let _ = self.send(&ControlRpcMessage {
102                            request_id,
103                            content: ControlRequest::Ping(Ping {
104                                now: now_milli(),
105                                current_ping: None,
106                                session_id: None,
107                            }),
108                        }).await;
109
110                        Err(SetupError::RegisterUnauthorized)
111                    },
112                    ControlResponse::Pong(pong) => {
113                        if pong.client_addr != auth_pong.client_addr || pong.tunnel_addr != auth_pong.tunnel_addr {
114                            Err(SetupError::AttemptingToAuthWithOldFlow)
115                        } else {
116                            continue;
117                        }
118                    }
119                    ControlResponse::RequestQueued => {
120                        tracing::info!("register queued, waiting 1s");
121                        tokio::time::sleep(Duration::from_secs(1)).await;
122                        break;
123                    }
124                    other => {
125                        tracing::error!(?other, "expected AgentRegistered but got something else");
126                        continue;
127                    }
128                };
129            }
130        }
131
132        Err(SetupError::FailedToConnect)
133    }
134
135    pub async fn send<M: MessageEncoding>(&mut self, msg: &M) -> std::io::Result<()> {
136        self.buffer.clear();
137        msg.write_to(&mut self.buffer)?;
138        self.packet_io.send_to(&self.buffer, self.control_addr).await?;
139        Ok(())
140    }
141
142    pub async fn recv(&mut self) -> Result<ControlFeed, ControlError> {
143        self.buffer.resize(1024, 0);
144
145        let (bytes, remote) = self.packet_io.recv_from(&mut self.buffer).await?;
146        if remote != self.control_addr {
147            return Err(ControlError::InvalidRemote { expected: self.control_addr, got: remote });
148        }
149
150        let mut reader = &self.buffer[..bytes];
151        let feed = ControlFeed::read_from(&mut reader).map_err(|e| ControlError::FailedToReadControlFeed(e))?;
152
153        if let ControlFeed::Response(ControlRpcMessage { content: ControlResponse::Pong(pong), .. }) = &feed {
154            self.pong_latest = pong.clone();
155        }
156
157        Ok(feed)
158    }
159}
160