datafusion_postgres/hooks/
cursor.rs1use std::sync::Arc;
2
3use async_trait::async_trait;
4use datafusion::common::ParamValues;
5use datafusion::logical_expr::LogicalPlan;
6use datafusion::prelude::SessionContext;
7use datafusion::sql::sqlparser;
8use datafusion::sql::sqlparser::ast::{CloseCursor, DeclareType, FetchDirection};
9use pgwire::api::ClientInfo;
10use pgwire::api::portal::{Format, Portal};
11use pgwire::api::results::{Response, Tag};
12use pgwire::api::stmt::StoredStatement;
13use pgwire::api::store::{MemPortalStore, PortalStore};
14use pgwire::error::{PgWireError, PgWireResult};
15
16use super::{HookClient, QueryHook};
17use crate::arrow_pg::datatypes::df;
18
19pub(crate) type DfStatement = (String, Option<(sqlparser::ast::Statement, LogicalPlan)>);
20
/// Query hook that implements SQL-level cursors: `DECLARE ... CURSOR FOR`,
/// `FETCH` and `CLOSE`.
///
/// The hook itself carries no state; cursors are kept in the client's portal
/// store, so the type is a cheap, copyable unit struct.
#[derive(Debug, Clone, Copy, Default)]
pub struct CursorStatementHook;
24
25#[async_trait]
26impl QueryHook for CursorStatementHook {
27 async fn handle_simple_query(
28 &self,
29 statement: &sqlparser::ast::Statement,
30 session_context: &SessionContext,
31 client: &mut dyn HookClient,
32 ) -> Option<PgWireResult<Response>> {
33 let store = client.portal_store();
34
35 match statement {
36 sqlparser::ast::Statement::Declare { stmts } => {
37 Some(handle_declare(store, stmts, session_context).await)
38 }
39 sqlparser::ast::Statement::Fetch {
40 name, direction, ..
41 } => Some(handle_fetch(store, name, direction).await),
42 sqlparser::ast::Statement::Close { cursor } => Some(handle_close(store, cursor)),
43 _ => None,
44 }
45 }
46
47 async fn handle_extended_parse_query(
48 &self,
49 statement: &sqlparser::ast::Statement,
50 _session_context: &SessionContext,
51 _client: &(dyn ClientInfo + Send + Sync),
52 ) -> Option<PgWireResult<LogicalPlan>> {
53 match statement {
54 sqlparser::ast::Statement::Declare { .. }
55 | sqlparser::ast::Statement::Fetch { .. }
56 | sqlparser::ast::Statement::Close { .. } => Some(Ok(LogicalPlan::EmptyRelation(
57 datafusion::logical_expr::EmptyRelation {
58 produce_one_row: false,
59 schema: Arc::new(datafusion::common::DFSchema::empty()),
60 },
61 ))),
62 _ => None,
63 }
64 }
65
66 async fn handle_extended_query(
67 &self,
68 statement: &sqlparser::ast::Statement,
69 _logical_plan: &LogicalPlan,
70 _params: &ParamValues,
71 session_context: &SessionContext,
72 client: &mut dyn HookClient,
73 ) -> Option<PgWireResult<Response>> {
74 let store = client.portal_store();
75
76 match statement {
77 sqlparser::ast::Statement::Declare { stmts } => {
78 Some(handle_declare(store, stmts, session_context).await)
79 }
80 sqlparser::ast::Statement::Fetch {
81 name, direction, ..
82 } => Some(handle_fetch(store, name, direction).await),
83 sqlparser::ast::Statement::Close { cursor } => Some(handle_close(store, cursor)),
84 _ => None,
85 }
86 }
87}
88
89async fn handle_declare(
90 store: &MemPortalStore<DfStatement>,
91 stmts: &[datafusion::sql::sqlparser::ast::Declare],
92 session_context: &SessionContext,
93) -> PgWireResult<Response> {
94 for declare in stmts {
95 if declare.declare_type != Some(DeclareType::Cursor) {
96 return Err(PgWireError::UserError(Box::new(
97 pgwire::error::ErrorInfo::new(
98 "ERROR".to_string(),
99 "42601".to_string(),
100 format!("unsupported DECLARE type: {:?}", declare.declare_type),
101 ),
102 )));
103 }
104
105 let cursor_name = match declare.names.first() {
106 Some(name) => name.value.clone(),
107 None => {
108 return Err(PgWireError::UserError(Box::new(
109 pgwire::error::ErrorInfo::new(
110 "ERROR".to_string(),
111 "42601".to_string(),
112 "cursor name is required".to_string(),
113 ),
114 )));
115 }
116 };
117
118 let for_query = match &declare.for_query {
119 Some(q) => q.to_string(),
120 None => {
121 return Err(PgWireError::UserError(Box::new(
122 pgwire::error::ErrorInfo::new(
123 "ERROR".to_string(),
124 "42601".to_string(),
125 "DECLARE CURSOR requires a FOR query".to_string(),
126 ),
127 )));
128 }
129 };
130
131 let df = session_context
132 .sql(&for_query)
133 .await
134 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
135
136 let query_response = df::encode_dataframe(df, &Format::UnifiedText, None).await?;
137
138 let stored_stmt = Arc::new(StoredStatement::new(
139 cursor_name.clone(),
140 (for_query, None),
141 vec![],
142 ));
143
144 let portal = Portal::new_cursor(cursor_name.clone(), stored_stmt);
145
146 portal.start(query_response).await;
147
148 store.put_portal(Arc::new(portal));
149 }
150
151 Ok(Response::Execution(Tag::new("DECLARE CURSOR")))
152}
153
154async fn handle_fetch(
155 store: &MemPortalStore<DfStatement>,
156 name: &datafusion::sql::sqlparser::ast::Ident,
157 direction: &FetchDirection,
158) -> PgWireResult<Response> {
159 let cursor_name = &name.value;
160
161 let max_rows = match direction {
162 FetchDirection::Next | FetchDirection::Forward { limit: None } => Some(1),
163 FetchDirection::Forward { limit: Some(v) } | FetchDirection::Count { limit: v } => {
164 parse_value_as_usize(v)
165 }
166 FetchDirection::ForwardAll | FetchDirection::All => None,
167 FetchDirection::Prior | FetchDirection::Backward { .. } | FetchDirection::BackwardAll => {
168 return Err(PgWireError::UserError(Box::new(
169 pgwire::error::ErrorInfo::new(
170 "ERROR".to_string(),
171 "42000".to_string(),
172 "cursor can only scan forward".to_string(),
173 ),
174 )));
175 }
176 FetchDirection::First
177 | FetchDirection::Last
178 | FetchDirection::Absolute { .. }
179 | FetchDirection::Relative { .. } => {
180 return Err(PgWireError::UserError(Box::new(
181 pgwire::error::ErrorInfo::new(
182 "ERROR".to_string(),
183 "42000".to_string(),
184 "cursor can only scan forward".to_string(),
185 ),
186 )));
187 }
188 };
189
190 let portal = store.get_portal(cursor_name).ok_or_else(|| {
191 PgWireError::UserError(Box::new(pgwire::error::ErrorInfo::new(
192 "ERROR".to_string(),
193 "34000".to_string(),
194 format!("cursor \"{cursor_name}\" does not exist"),
195 )))
196 })?;
197
198 let fetch_result = portal.fetch(max_rows.unwrap_or(0)).await?;
199
200 Ok(Response::Query(fetch_result.response))
201}
202
203fn handle_close(
204 store: &MemPortalStore<DfStatement>,
205 cursor: &CloseCursor,
206) -> PgWireResult<Response> {
207 match cursor {
208 CloseCursor::All => {
209 store.clear_portals();
210 }
211 CloseCursor::Specific { name } => {
212 store.rm_portal(&name.value);
213 }
214 }
215 Ok(Response::Execution(Tag::new("CLOSE CURSOR")))
216}
217
218fn parse_value_as_usize(value: &datafusion::sql::sqlparser::ast::Value) -> Option<usize> {
219 match value {
220 datafusion::sql::sqlparser::ast::Value::Number(s, _) => s.parse().ok(),
221 _ => None,
222 }
223}