1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362
#![no_std] use core::{ future::Future, mem, pin::Pin, task::{Context, Poll}, }; use futures::{ channel::oneshot, future::{self, FusedFuture}, ready, stream::Skip, FutureExt, Sink, SinkExt, Stream, StreamExt, }; type ChainSend<St> = oneshot::Sender<ResolverChainItem<St>>; /// ChainRecv is a helper wrapper around a Receiver of ResolverChainItem. It's /// Future designed to handle the logic for those items; it aggregates the /// skips and updates the receiver, until the actual stream item arrives, which /// it then resolves to. #[derive(Debug)] struct ChainRecv<St> { recv: oneshot::Receiver<ResolverChainItem<St>>, skip: usize, } impl<St> ChainRecv<St> { fn new(recv: oneshot::Receiver<ResolverChainItem<St>>) -> Self { Self { recv, skip: 0 } } } impl<St> Future for ChainRecv<St> { type Output = Option<(St, usize)>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { let this = self.get_mut(); match ready!(this.recv.poll_unpin(cx)) { Err(..) => Poll::Ready(None), Ok(ResolverChainItem::Reconnect(recv)) => { this.recv = recv.recv; this.skip += recv.skip; cx.waker().wake_by_ref(); Poll::Pending } Ok(ResolverChainItem::Stream { stream, skip }) => { Poll::Ready(Some((stream, this.skip + skip))) } } } } #[derive(Debug)] enum ResolverChainItem<St> { Reconnect(ChainRecv<St>), Stream { stream: St, skip: usize }, } #[derive(Debug)] enum ResolverState<St> { Chain { recv: ChainRecv<St>, send: ChainSend<St>, }, Stream { skip: usize, stream: St, send: ChainSend<St>, }, // We should only be in the dead state transiently between transitions, // or after we returned Poll::Ready Dead, } impl<St> ResolverState<St> { /// Replace the state with Dead and return the previous state #[inline] fn take(&mut self) -> Self { mem::replace(self, ResolverState::Dead) } } /// A [`Future`] associated with a request submitted through a [`Pipeline`]. /// /// When you successfully submit a request to a [`Pipeline`], it returns a /// `Resolver` that can be awaited to retrieve the response for that request. /// Because responses are retrieved lazily and in order, *each* `Resolver` must /// be awaited in order to receive the responses; later Resolvers will block /// indefinitely until earlier Resolvers have returned their responses. /// /// If the [`Stream`] used by the [`Pipeline`] to retrieve responses closes /// prematurely, all remaining (and new) Resolvers will return `None`. Ideally /// this shouldn't happen; the stream should return Some(Err(...)) to each /// resolver in the event of (for example) an unrecoverable connection failure. /// /// If a `Resolver` is dropped, the response associated with it will simply /// be discarded. #[derive(Debug)] #[must_use = "Resolvers do nothing unless polled"] pub struct Resolver<St: Stream + Unpin> { state: ResolverState<St>, } impl<St: Stream + Unpin> Future for Resolver<St> { type Output = Option<St::Item>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { let this = self.get_mut(); // Some design notes for this function: // - We try to be as conservative as possible with writes. This means // we don't modify the state until we *know* it's changing; we don't // .take() the state, process it, and restore it. This allows us // to ensure the state remains consistent even through ready! calls // and possible panics. // - We could loop here until we receive a Pending from one of our // inner polls or until we're ready, but we don't want to starve // the event loop, so we do at most one state update, then auto // awaken the context if we can continue to make progress. match this.state { ResolverState::Dead => panic!("Can't re-poll a completed future"), // The chain state means a previous receiver in the chain has the stream // right now, and we're waiting for it to eventually come to use. Additionally, // We may receive "reconnect" messages, which indicate that our previous // resolver is aborting and is sending ITS previous resolver to ensure the // chain isn't broken (along with a skip count, indicating the total number // of items from the stream that are associated with aborted resolvers and // need to be discarded). ResolverState::Chain { ref mut recv, .. } => { match ready!(recv.poll_unpin(cx)) { // Our channel was closed without a send, indicating the stream returned // None at some point. Clear our state to propagate the None to future // Resolvers, then return it. None => { this.state = ResolverState::Dead; Poll::Ready(None) } // The previous Resolver has finished, which means it's our turn to drink // from the stream. It's sent us the stream, along with a skip count in // the event it aborted early. Some((stream, skip)) => match this.state.take() { ResolverState::Chain { send, .. } => { this.state = ResolverState::Stream { stream, send, skip }; cx.waker().wake_by_ref(); Poll::Pending } _ => unreachable!(), }, } } // We are the current holder of the stream, so we're waiting for our element. // If we have a skip, that means previous Resolvers aborted without resolving, // which means that we need to take and discard that many elements before // claiming our own. ResolverState::Stream { ref mut stream, ref mut skip, .. } => { match ready!(stream.poll_next_unpin(cx)) { // Stream ended. Clear the state and return the None. // Clearing the state will close the send channel, which // will in activate the next Resolver in the chain and // so on. None => { this.state = ResolverState::Dead; Poll::Ready(None) } // We got an item, but we still have skips, which means // it's an item associated with a previous aborted Resolver. // Update skip and retry the loop. Some(..) if *skip > 0 => { *skip -= 1; cx.waker().wake_by_ref(); Poll::Pending } // We got our item! Send the stream down the line, then // clear our own state and return it Some(item) => match this.state.take() { ResolverState::Stream { stream, send, .. } => { // If the send channel is closed, that means that // it was part of the pipeline, which was dropped. // we can therefore silently let this send fail. let _ = send.send(ResolverChainItem::Stream { stream, skip: 0 }); Poll::Ready(Some(item)) } _ => unreachable!(), }, } } } } } impl<St: Stream + Unpin> FusedFuture for Resolver<St> { fn is_terminated(&self) -> bool { matches!(self.state, ResolverState::Dead) } } impl<St: Stream + Unpin> Drop for Resolver<St> { fn drop(&mut self) { // When a resolver is dropped, we need to make sure the chain is // unbroken, so we forward our state via a ResolverChainItem to the // next Resolver in the chain, along with an incremented skip so that // it knows to skip our response. match self.state.take() { ResolverState::Chain { mut recv, send } => { recv.skip += 1; let _ = send.send(ResolverChainItem::Reconnect(recv)); } ResolverState::Stream { stream, send, skip } => { let _ = send.send(ResolverChainItem::Stream { stream, skip: skip + 1, }); } ResolverState::Dead => {} }; } } /// A `Pipeline` manages sending requests through a stream and retrieving their /// matching responses. /// /// A pipeline manages request/response flow through a [`Sink`] and associated /// [`Stream`]. The two halves should be set up such that: /// /// - Items sent into the [`Sink`] are submitted to some underlying system as /// requests. /// - That system replies to each request with a response, in order, via the /// [`Stream`]. /// /// This could be HTTP requests sent through a `Keep-Alive` connection, Redis /// interactions through the [Redis Protocol], or anything else. /// /// The `Pipeline` provides a [`submit`][Pipeline::submit] method, which /// submits a request. Pipelines are backpressure sensitive, so this method /// will block until the underlying [`Sink`] can accept it. Pipelines do not do /// any extra buffering, so if you're using it to enqueue several requests at /// once, be sure that the [`Sink`] has been set up with its own buffering. /// /// The [`submit`][Pipeline::submit] method will return a [`Resolver`] /// associated with that request. This `Resolver` is a future which will, when /// awaited, resolve to the response associated with the Request. These /// resolvers must be awaited or dropped in order to consume responses from the /// underlying stream; be sure to set up concurrency or buffering to ensure /// your request submissions don't get stuck because the system is waiting for /// a response to be collected. /// /// [Redis Protocol]: https://redis.io/topics/protocol #[derive(Debug)] pub struct Pipeline<Si, St> { sink: Si, recv: oneshot::Receiver<ResolverChainItem<St>>, } impl<Si: Unpin, St: Unpin + Stream> Pipeline<Si, St> { /// Construct a new `Pipeline` with associated channels for requests and /// responses. In order for the Pipeline's logic to function correctly, /// this pair must be set up such that each request submitted through /// `requests` eventually results in a `response` being sent back through /// `responses`, in order. pub fn new<T>(requests: Si, responses: St) -> Self where Si: Sink<T>, { let (send, recv) = oneshot::channel(); send.send(ResolverChainItem::Stream { stream: responses, skip: 0, }) .unwrap_or_else(|_| unreachable!()); Self { sink: requests, recv, } } /// Submit a request to this `Pipeline`, blocking until it can be sent to /// the underlying `Sink`. Returns a [`Resolver`] that can be used to /// await the response, or an error if the `Sink` returned an error. pub async fn submit<T>(&mut self, item: T) -> Result<Resolver<St>, Si::Error> where Si: Sink<T>, { future::poll_fn(|cx| self.sink.poll_ready_unpin(cx)).await?; self.sink.start_send_unpin(item)?; let (send, recv) = oneshot::channel(); // Swap out the receive end. We now have a receive end connected to the // previous Resolver, and a send end connected to this.recv let recv = mem::replace(&mut self.recv, recv); Ok(Resolver { state: ResolverState::Chain { recv: ChainRecv::new(recv), send, }, }) } /// Submit a request to this `Pipeline`. Same as [`submit`][Pipeline::submit], /// but this takes `self` by move and returns it as part of the result, which /// can make it easier to construct chained futures (for instance, via /// [`.then`][FutureExt::then]). pub async fn submit_owned<T>(mut self, item: T) -> (Self, Result<Resolver<St>, Si::Error>) where Si: Sink<T>, { let res = self.submit(item).await; (self, res) } } impl<Si: Unpin, St> Pipeline<Si, St> { /// Flush the underlying `Sink`, blocking until it's finished. Note that, /// depending on your request/response system, you may also need to be sure /// that any incomplete `Resolvers` are also being awaited so that the /// responses can be drained; this method only handles flushing the /// requests side. pub async fn flush<'a, T>(&mut self) -> Result<(), Si::Error> where Si: Sink<T>, { future::poll_fn(move |cx| self.sink.poll_flush_unpin(cx)).await } } impl<Si, St: Stream> Pipeline<Si, St> { /// Finish the pipeline. Wait for all the Resolvers to complete (or abort), /// then return the original sink & stream. If the stream completed during /// the resolvers, return None instead of the stream. /// /// This function returns a `Skip<St>` so that any responses associated /// with aborted Resolvers will be skipped. pub async fn finish(self) -> (Si, Option<Skip<St>>) { let recv = ChainRecv::new(self.recv); ( self.sink, recv.await.map(|(stream, skip)| stream.skip(skip)), ) } }