use crate::Cirrus;
use crate::error::CirrusResult;
use crate::response::QueryResult;
use futures::future::BoxFuture;
use futures::stream::Stream;
use serde::de::DeserializeOwned;
use std::collections::VecDeque;
use std::pin::Pin;
use std::task::{Context, Poll};
/// Boxed `'static` future that resolves to one `QueryResult<R>` page or an error.
type PageFuture<R> = BoxFuture<'static, CirrusResult<QueryResult<R>>>;
/// Stream of query records that walks result pages one at a time.
///
/// The stream holds either an in-flight page request or the records of the
/// page that most recently arrived. Nothing is requested until the stream is
/// polled, so a value that is created and then discarded performs no work.
#[must_use = "streams do nothing unless polled"]
pub struct Records<R> {
    /// Client handle; cloned into each follow-up page request.
    client: Cirrus,
    /// Current position in the fetch/drain cycle.
    state: State<R>,
}
impl<R> std::fmt::Debug for Records<R> {
    /// Writes a summary of the stream state.
    ///
    /// Only the state name is printed, plus the buffered record count and
    /// whether another page is pending when a page is buffered. Record
    /// contents are never formatted, so no `R: Debug` bound is needed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("Records");
        match &self.state {
            State::Fetching(_) => {
                out.field("state", &"Fetching");
            }
            State::Buffered { records, next } => {
                out.field("state", &"Buffered")
                    .field("buffered_records", &records.len())
                    .field("has_next_page", &next.is_some());
            }
            State::Done => {
                out.field("state", &"Done");
            }
        }
        out.finish_non_exhaustive()
    }
}
/// Internal state machine behind the `Records` stream.
enum State<R> {
/// A page request is in flight; holds the future that is being polled.
Fetching(PageFuture<R>),
/// A page has arrived and its records are handed out one at a time.
Buffered {
/// Records of the current page that have not been yielded yet.
records: VecDeque<R>,
/// Locator of the following page (from `next_records_url`); `None` when no further page exists.
next: Option<String>,
},
/// Terminal state, entered after an error was yielded or the last page was drained.
Done,
}
impl<R> Records<R>
where
    R: DeserializeOwned + Send + Unpin + 'static,
{
    /// Creates a stream that starts out waiting on `initial` for its first page.
    ///
    /// `client` is stored so that follow-up pages can be requested later.
    pub(crate) fn new(client: Cirrus, initial: PageFuture<R>) -> Self {
        let state = State::Fetching(initial);
        Self { client, state }
    }
}
impl<R: DeserializeOwned + Send + Unpin + 'static> Stream for Records<R> {
    type Item = CirrusResult<R>;

    /// Yields the next record, requesting the following page only once the
    /// current page's buffer is empty.
    ///
    /// A failed page request is yielded as a single `Err` item, after which
    /// the stream is finished and keeps returning `None`.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();
        // Loop so that a state transition (page arrived / next page requested)
        // is acted upon within the same poll instead of returning early.
        loop {
            match &mut this.state {
                State::Fetching(fut) => match fut.as_mut().poll(cx) {
                    Poll::Ready(Ok(qr)) => {
                        this.state = State::Buffered {
                            records: qr.records.into(),
                            next: qr.next_records_url,
                        };
                    }
                    Poll::Ready(Err(e)) => {
                        // Move to `Done` first so later polls return `None`
                        // instead of re-polling the completed future.
                        this.state = State::Done;
                        return Poll::Ready(Some(Err(e)));
                    }
                    Poll::Pending => return Poll::Pending,
                },
                State::Buffered { records, next } => {
                    if let Some(rec) = records.pop_front() {
                        return Poll::Ready(Some(Ok(rec)));
                    }
                    // Buffer drained: request the next page if there is one.
                    if let Some(next_url) = next.take() {
                        let client = this.client.clone();
                        let fut: PageFuture<R> =
                            Box::pin(async move { client.query_more_as::<R>(&next_url).await });
                        this.state = State::Fetching(fut);
                    } else {
                        this.state = State::Done;
                        return Poll::Ready(None);
                    }
                }
                State::Done => return Poll::Ready(None),
            }
        }
    }

    /// Reports how many items are certain to remain.
    ///
    /// Buffered records are always yielded, so they form the lower bound.
    /// The upper bound is known only when no further page is pending or the
    /// stream has finished.
    fn size_hint(&self) -> (usize, Option<usize>) {
        match &self.state {
            State::Fetching(_) => (0, None),
            State::Buffered { records, next } => {
                let buffered = records.len();
                let upper = if next.is_none() { Some(buffered) } else { None };
                (buffered, upper)
            }
            State::Done => (0, Some(0)),
        }
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
mod tests {
use crate::Cirrus;
use crate::auth::StaticTokenAuth;
use futures::StreamExt;
use serde_json::{Value, json};
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use wiremock::matchers::{method, path, query_param};
use wiremock::{Mock, MockServer, Request, ResponseTemplate};
/// Builds a client that authenticates with the static token `"tok"` against
/// `uri` and has retries disabled.
fn fixture(uri: String) -> Cirrus {
    let token_auth = Arc::new(StaticTokenAuth::new("tok", uri));
    let built = Cirrus::builder()
        .auth(token_auth)
        .retry_policy(crate::RetryPolicy::none())
        .build();
    built.unwrap()
}
/// A single `done: true` page yields both of its records while the query
/// endpoint is expected to be hit exactly once.
#[tokio::test]
async fn stream_drains_single_page_without_extra_fetches() {
    let only_page = json!({
        "totalSize": 2,
        "done": true,
        "records": [
            {"attributes": {"type": "Account"}, "Id": "001a"},
            {"attributes": {"type": "Account"}, "Id": "001b"}
        ]
    });
    let mock_server = MockServer::start().await;
    Mock::given(method("GET"))
        .and(path("/services/data/v66.0/query"))
        .and(query_param("q", "SELECT Id FROM Account"))
        .respond_with(ResponseTemplate::new(200).set_body_json(only_page))
        .expect(1)
        .mount(&mock_server)
        .await;

    let client = fixture(mock_server.uri());
    let mut stream = client.query_stream("SELECT Id FROM Account");
    let mut drained: Vec<Value> = Vec::new();
    while let Some(item) = stream.next().await {
        drained.push(item.unwrap());
    }

    assert_eq!(drained.len(), 2);
    assert_eq!(drained[0]["Id"], "001a");
    assert_eq!(drained[1]["Id"], "001b");
}
/// Three chained pages are each expected to be requested once, and their
/// records must surface in exactly the order the pages list them.
#[tokio::test]
async fn stream_walks_three_paginated_pages_in_order() {
    let server = MockServer::start().await;
    // Page 1: the initial query, pointing at locator `01gAA-2`.
    Mock::given(method("GET"))
        .and(path("/services/data/v66.0/query"))
        .and(query_param("q", "SELECT Id FROM Account"))
        .respond_with(ResponseTemplate::new(200).set_body_json(json!({
            "totalSize": 6,
            "done": false,
            "nextRecordsUrl": "/services/data/v66.0/query/01gAA-2",
            "records": [
                {"attributes": {"type": "Account"}, "Id": "001a"},
                {"attributes": {"type": "Account"}, "Id": "001b"}
            ]
        })))
        .expect(1)
        .mount(&server)
        .await;
    // Page 2: pointing at locator `01gAA-4`.
    Mock::given(method("GET"))
        .and(path("/services/data/v66.0/query/01gAA-2"))
        .respond_with(ResponseTemplate::new(200).set_body_json(json!({
            "totalSize": 6,
            "done": false,
            "nextRecordsUrl": "/services/data/v66.0/query/01gAA-4",
            "records": [
                {"attributes": {"type": "Account"}, "Id": "001c"},
                {"attributes": {"type": "Account"}, "Id": "001d"}
            ]
        })))
        .expect(1)
        .mount(&server)
        .await;
    // Page 3: final page, no further locator.
    Mock::given(method("GET"))
        .and(path("/services/data/v66.0/query/01gAA-4"))
        .respond_with(ResponseTemplate::new(200).set_body_json(json!({
            "totalSize": 6,
            "done": true,
            "records": [
                {"attributes": {"type": "Account"}, "Id": "001e"},
                {"attributes": {"type": "Account"}, "Id": "001f"}
            ]
        })))
        .expect(1)
        .mount(&server)
        .await;
    let sf = fixture(server.uri());
    let records: Vec<Value> = sf
        .query_stream("SELECT Id FROM Account")
        .map(|r| r.unwrap())
        .collect()
        .await;
    // Compare the whole Id sequence: checking only the first and last record
    // would let a mis-ordered or duplicated middle record slip through.
    let ids: Vec<&str> = records
        .iter()
        .map(|rec| rec["Id"].as_str().unwrap())
        .collect();
    assert_eq!(ids, ["001a", "001b", "001c", "001d", "001e", "001f"]);
}
/// A 503 on the second page is yielded as one `Err` item after the first
/// page's record, and the stream then ends.
#[tokio::test]
async fn stream_surfaces_mid_iteration_error_then_terminates() {
    let first_page = json!({
        "totalSize": 2500,
        "done": false,
        "nextRecordsUrl": "/services/data/v66.0/query/01gAA-2",
        "records": [
            {"attributes": {"type": "Account"}, "Id": "001a"}
        ]
    });
    let unavailable = json!([{
        "errorCode": "SERVER_UNAVAILABLE",
        "message": "Service Unavailable"
    }]);

    let mock_server = MockServer::start().await;
    Mock::given(method("GET"))
        .and(path("/services/data/v66.0/query"))
        .respond_with(ResponseTemplate::new(200).set_body_json(first_page))
        .expect(1)
        .mount(&mock_server)
        .await;
    Mock::given(method("GET"))
        .and(path("/services/data/v66.0/query/01gAA-2"))
        .respond_with(ResponseTemplate::new(503).set_body_json(unavailable))
        .expect(1)
        .mount(&mock_server)
        .await;

    let client = fixture(mock_server.uri());
    let mut stream = client.query_stream("SELECT Id FROM Account");

    let head = stream.next().await.unwrap().unwrap();
    assert_eq!(head["Id"], "001a");

    let failure = stream.next().await.unwrap().unwrap_err();
    assert!(matches!(failure, crate::CirrusError::Api { status: 503, .. }));

    let after_failure = stream.next().await;
    assert!(after_failure.is_none());
}
/// Reading only the first page's two records and then dropping the stream
/// must leave the second-page locator unrequested.
#[tokio::test]
async fn dropping_stream_early_does_not_fetch_unconsumed_pages() {
    let first_page = json!({
        "totalSize": 100,
        "done": false,
        "nextRecordsUrl": "/services/data/v66.0/query/01gAA-2",
        "records": [
            {"attributes": {"type": "Account"}, "Id": "001a"},
            {"attributes": {"type": "Account"}, "Id": "001b"}
        ]
    });

    let mock_server = MockServer::start().await;
    Mock::given(method("GET"))
        .and(path("/services/data/v66.0/query"))
        .respond_with(ResponseTemplate::new(200).set_body_json(first_page))
        .expect(1)
        .mount(&mock_server)
        .await;
    // Registered with `expect(0)`: the second-page locator must never be hit.
    Mock::given(method("GET"))
        .and(path("/services/data/v66.0/query/01gAA-2"))
        .respond_with(ResponseTemplate::new(500))
        .expect(0)
        .mount(&mock_server)
        .await;

    let client = fixture(mock_server.uri());
    let mut stream = client.query_stream("SELECT Id FROM Account");
    let record_a = stream.next().await.unwrap().unwrap();
    let record_b = stream.next().await.unwrap().unwrap();
    assert_eq!(record_a["Id"], "001a");
    assert_eq!(record_b["Id"], "001b");
    drop(stream);
}
/// `query_stream_as` yields records deserialized into the caller's own struct.
#[tokio::test]
async fn stream_deserializes_into_caller_type() {
    /// Minimal projection of an Account row: only the `Name` column.
    #[derive(serde::Deserialize, Debug)]
    struct Acct {
        #[serde(rename = "Name")]
        name: String,
    }

    let page = json!({
        "totalSize": 2,
        "done": true,
        "records": [
            {"attributes": {"type": "Account"}, "Name": "Acme"},
            {"attributes": {"type": "Account"}, "Name": "Globex"}
        ]
    });
    let mock_server = MockServer::start().await;
    Mock::given(method("GET"))
        .and(path("/services/data/v66.0/query"))
        .respond_with(ResponseTemplate::new(200).set_body_json(page))
        .mount(&mock_server)
        .await;

    let client = fixture(mock_server.uri());
    let account_names: Vec<String> = client
        .query_stream_as::<Acct>("SELECT Name FROM Account")
        .map(|row| row.unwrap().name)
        .collect()
        .await;
    assert_eq!(account_names, ["Acme", "Globex"]);
}
/// `query_all_stream` is served by the mock mounted on the `queryAll` path.
#[tokio::test]
async fn query_all_stream_targets_queryall_endpoint() {
    let page = json!({
        "totalSize": 1,
        "done": true,
        "records": [
            {"attributes": {"type": "Account"}, "Id": "001x", "IsDeleted": true}
        ]
    });
    let mock_server = MockServer::start().await;
    Mock::given(method("GET"))
        .and(path("/services/data/v66.0/queryAll"))
        .respond_with(ResponseTemplate::new(200).set_body_json(page))
        .mount(&mock_server)
        .await;

    let client = fixture(mock_server.uri());
    let rows: Vec<Value> = client
        .query_all_stream("SELECT Id, IsDeleted FROM Account")
        .map(|row| row.unwrap())
        .collect()
        .await;
    assert_eq!(rows.len(), 1);
    assert_eq!(rows[0]["IsDeleted"], true);
}
/// The tooling stream follows a `nextRecordsUrl` under the `/tooling/` prefix
/// and yields the records of both pages.
#[tokio::test]
async fn tooling_query_stream_walks_tooling_prefixed_locators() {
    let opening_page = json!({
        "totalSize": 4,
        "done": false,
        "nextRecordsUrl": "/services/data/v66.0/tooling/query/01gAA-2",
        "records": [
            {"attributes": {"type": "ApexClass"}, "Id": "01p1"},
            {"attributes": {"type": "ApexClass"}, "Id": "01p2"}
        ]
    });
    let closing_page = json!({
        "totalSize": 4,
        "done": true,
        "records": [
            {"attributes": {"type": "ApexClass"}, "Id": "01p3"},
            {"attributes": {"type": "ApexClass"}, "Id": "01p4"}
        ]
    });

    let mock_server = MockServer::start().await;
    Mock::given(method("GET"))
        .and(path("/services/data/v66.0/tooling/query"))
        .respond_with(ResponseTemplate::new(200).set_body_json(opening_page))
        .expect(1)
        .mount(&mock_server)
        .await;
    Mock::given(method("GET"))
        .and(path("/services/data/v66.0/tooling/query/01gAA-2"))
        .respond_with(ResponseTemplate::new(200).set_body_json(closing_page))
        .expect(1)
        .mount(&mock_server)
        .await;

    let client = fixture(mock_server.uri());
    let classes: Vec<Value> = client
        .tooling()
        .query_stream("SELECT Id FROM ApexClass")
        .map(|row| row.unwrap())
        .collect()
        .await;
    assert_eq!(classes.len(), 4);
    assert_eq!(classes[0]["Id"], "01p1");
    assert_eq!(classes[3]["Id"], "01p4");
}
/// Counts HTTP hits across both pages and checks that the second page is
/// requested only when the third record is pulled.
#[tokio::test]
async fn stream_yields_buffered_records_before_fetching_next_page() {
    let mock_server = MockServer::start().await;
    // One counter shared by both responders.
    let fetch_count = Arc::new(AtomicUsize::new(0));

    let first_page_hits = Arc::clone(&fetch_count);
    let first_page = json!({
        "totalSize": 4,
        "done": false,
        "nextRecordsUrl": "/services/data/v66.0/query/01gAA-2",
        "records": [
            {"attributes": {"type": "Account"}, "Id": "001a"},
            {"attributes": {"type": "Account"}, "Id": "001b"}
        ]
    });
    Mock::given(method("GET"))
        .and(path("/services/data/v66.0/query"))
        .respond_with(move |_request: &Request| {
            first_page_hits.fetch_add(1, Ordering::SeqCst);
            ResponseTemplate::new(200).set_body_json(&first_page)
        })
        .mount(&mock_server)
        .await;

    let second_page_hits = Arc::clone(&fetch_count);
    let second_page = json!({
        "totalSize": 4,
        "done": true,
        "records": [
            {"attributes": {"type": "Account"}, "Id": "001c"},
            {"attributes": {"type": "Account"}, "Id": "001d"}
        ]
    });
    Mock::given(method("GET"))
        .and(path("/services/data/v66.0/query/01gAA-2"))
        .respond_with(move |_request: &Request| {
            second_page_hits.fetch_add(1, Ordering::SeqCst);
            ResponseTemplate::new(200).set_body_json(&second_page)
        })
        .mount(&mock_server)
        .await;

    let client = fixture(mock_server.uri());
    let mut stream = client.query_stream("SELECT Id FROM Account");
    // Expected cumulative request count after each of the four records.
    for expected_fetches in [1, 1, 2, 2] {
        let _record = stream.next().await.unwrap().unwrap();
        assert_eq!(fetch_count.load(Ordering::SeqCst), expected_fetches);
    }
    assert!(stream.next().await.is_none());
}
}