github_copilot_sdk/
subscription.rs1use std::fmt;
30use std::pin::Pin;
31use std::task::{Context, Poll};
32
33use tokio::sync::broadcast::Receiver;
34use tokio_stream::wrappers::BroadcastStream;
35use tokio_stream::wrappers::errors::BroadcastStreamRecvError;
36use tokio_stream::{Stream, StreamExt as _};
37
38use crate::types::{SessionEvent, SessionLifecycleEvent};
39use crate::{Custom, Repr};
40
/// Error reported when a subscriber falls behind its broadcast channel and
/// some events are skipped before they can be delivered.
///
/// The subscription remains usable afterwards: the next receive resumes with
/// the oldest event the channel still retains.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Lagged(pub(crate) u64);

impl Lagged {
    /// Number of events that were skipped because this subscriber lagged.
    pub fn skipped(&self) -> u64 {
        let Self(count) = *self;
        count
    }
}

impl fmt::Display for Lagged {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self(count) = self;
        write!(f, "subscription lagged behind by {} events", count)
    }
}

/// `Lagged` is a leaf error: it wraps no underlying cause.
impl std::error::Error for Lagged {}
65
66#[derive(Clone, Copy, Debug, PartialEq, Eq)]
68#[non_exhaustive]
69pub enum RecvErrorKind {
70 Closed,
73
74 Lagged(Lagged),
76}
77
78impl fmt::Display for RecvErrorKind {
79 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
80 match self {
81 RecvErrorKind::Closed => write!(f, "subscription closed"),
82 RecvErrorKind::Lagged(l) => write!(f, "{l}"),
83 }
84 }
85}
86
87#[derive(Debug)]
90pub struct RecvError {
91 repr: Repr<RecvErrorKind>,
92}
93
94impl RecvError {
95 pub fn kind(&self) -> &RecvErrorKind {
97 match &self.repr {
98 Repr::Simple(k) | Repr::SimpleMessage(k, ..) | Repr::Custom(Custom { kind: k, .. }) => {
99 k
100 }
101 }
102 }
103}
104
105impl fmt::Display for RecvError {
106 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
107 match &self.repr {
108 Repr::Simple(k) => write!(f, "{k}"),
109 Repr::SimpleMessage(_, m) => write!(f, "{m}"),
110 Repr::Custom(Custom { error, .. }) => write!(f, "{error}"),
111 }
112 }
113}
114
115impl std::error::Error for RecvError {
116 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
117 match &self.repr {
118 Repr::Custom(Custom { error, .. }) => Some(&**error),
119 _ => None,
120 }
121 }
122}
123
124impl From<RecvErrorKind> for RecvError {
125 fn from(kind: RecvErrorKind) -> Self {
126 Self {
127 repr: Repr::Simple(kind),
128 }
129 }
130}
131
132impl From<Lagged> for RecvError {
133 fn from(lagged: Lagged) -> Self {
134 Self::from(RecvErrorKind::Lagged(lagged))
135 }
136}
137
/// Generates a public subscription type that wraps a `BroadcastStream<$item>`.
///
/// The generated type provides:
/// - an async `recv` method that reports lag and closure as [`RecvError`];
/// - a [`Stream`] impl yielding `Result<$item, Lagged>` that ends when the
///   channel closes.
///
/// Attributes written before the type name, including doc comments, are
/// forwarded to the generated struct.
macro_rules! define_subscription {
    (
        $(#[$meta:meta])*
        $name:ident, $item:ty $(,)?
    ) => {
        $(#[$meta])*
        #[must_use = "subscriptions are inert until polled"]
        pub struct $name {
            inner: BroadcastStream<$item>,
        }

        impl $name {
            /// Wraps a broadcast receiver in a subscription.
            pub(crate) fn new(rx: Receiver<$item>) -> Self {
                let inner = BroadcastStream::new(rx);
                Self { inner }
            }

            /// Waits for the next item.
            ///
            /// # Errors
            ///
            /// - [`RecvErrorKind::Lagged`] if events were skipped because this
            ///   subscriber fell behind. Calling `recv` again resumes with the
            ///   oldest event still retained by the channel.
            /// - [`RecvErrorKind::Closed`] once the channel has closed and no
            ///   buffered items remain.
            pub async fn recv(&mut self) -> Result<$item, RecvError> {
                // Reuse this type's own `Stream` impl so the lag mapping lives
                // in exactly one place.
                match self.next().await {
                    Some(item) => item.map_err(RecvError::from),
                    None => Err(RecvError::from(RecvErrorKind::Closed)),
                }
            }
        }

        impl Stream for $name {
            type Item = Result<$item, Lagged>;

            fn poll_next(
                self: Pin<&mut Self>,
                cx: &mut Context<'_>,
            ) -> Poll<Option<Self::Item>> {
                // `BroadcastStream` is `Unpin`, so plain `&mut` access is fine.
                let this = self.get_mut();
                Pin::new(&mut this.inner).poll_next(cx).map(|ready| {
                    ready.map(|res| {
                        res.map_err(|BroadcastStreamRecvError::Lagged(n)| Lagged(n))
                    })
                })
            }
        }
    };
}
202
203define_subscription! {
204 EventSubscription, SessionEvent
211}
212
213define_subscription! {
214 LifecycleSubscription, SessionLifecycleEvent
221}
222
#[cfg(test)]
mod tests {
    use tokio::sync::broadcast;

    use super::*;

    /// Builds a minimal `SessionEvent` whose only distinguishing field is `id`.
    fn make_event(id: &str) -> SessionEvent {
        SessionEvent {
            id: id.into(),
            timestamp: "2025-01-01T00:00:00Z".into(),
            parent_id: None,
            ephemeral: None,
            agent_id: None,
            debug_cli_received_at_ms: None,
            debug_ws_forwarded_at_ms: None,
            event_type: "noop".into(),
            data: serde_json::json!({}),
        }
    }

    #[tokio::test]
    async fn recv_yields_then_closes_on_drop_sender() {
        let (tx, rx) = broadcast::channel(8);
        let mut sub = EventSubscription::new(rx);
        tx.send(make_event("a")).unwrap();
        tx.send(make_event("b")).unwrap();
        drop(tx);

        assert_eq!(sub.recv().await.unwrap().id, "a");
        assert_eq!(sub.recv().await.unwrap().id, "b");
        assert!(matches!(
            sub.recv().await.unwrap_err().kind(),
            RecvErrorKind::Closed
        ));
    }

    #[tokio::test]
    async fn recv_surfaces_lag() {
        let (tx, rx) = broadcast::channel(2);
        let mut sub = EventSubscription::new(rx);
        for id in ["a", "b", "c", "d"] {
            tx.send(make_event(id)).unwrap();
        }
        let err = sub.recv().await.expect_err("expected a Lagged error");
        let RecvErrorKind::Lagged(l) = err.kind() else {
            panic!("expected Lagged, got {:?}", err.kind());
        };
        assert_eq!(l.skipped(), 2);
        // After reporting the lag, delivery resumes from the oldest retained event.
        assert_eq!(sub.recv().await.unwrap().id, "c");
        assert_eq!(sub.recv().await.unwrap().id, "d");
    }

    #[tokio::test]
    async fn stream_impl_matches_recv_semantics() {
        let (tx, rx) = broadcast::channel(8);
        let mut sub = EventSubscription::new(rx);
        tx.send(make_event("a")).unwrap();
        drop(tx);

        let next = sub.next().await;
        assert_eq!(next.unwrap().unwrap().id, "a");
        assert!(sub.next().await.is_none());
    }

    /// The `Stream` impl must report lag as `Err(Lagged)` and then keep
    /// yielding, just as `recv` does.
    #[tokio::test]
    async fn stream_impl_surfaces_lag() {
        let (tx, rx) = broadcast::channel(2);
        let mut sub = EventSubscription::new(rx);
        for id in ["a", "b", "c", "d"] {
            tx.send(make_event(id)).unwrap();
        }

        let lagged = sub
            .next()
            .await
            .expect("stream ended unexpectedly")
            .expect_err("expected a Lagged item");
        assert_eq!(lagged.skipped(), 2);
        assert_eq!(sub.next().await.unwrap().unwrap().id, "c");
        assert_eq!(sub.next().await.unwrap().unwrap().id, "d");
    }
}