Skip to main content

datafusion_postgres/hooks/
cursor.rs

1use std::sync::Arc;
2
3use async_trait::async_trait;
4use datafusion::common::ParamValues;
5use datafusion::logical_expr::LogicalPlan;
6use datafusion::prelude::SessionContext;
7use datafusion::sql::sqlparser;
8use datafusion::sql::sqlparser::ast::{CloseCursor, DeclareType, FetchDirection};
9use pgwire::api::ClientInfo;
10use pgwire::api::portal::{Format, Portal};
11use pgwire::api::results::{Response, Tag};
12use pgwire::api::stmt::StoredStatement;
13use pgwire::api::store::{MemPortalStore, PortalStore};
14use pgwire::error::{PgWireError, PgWireResult};
15
16use super::{HookClient, QueryHook};
17use crate::arrow_pg::datatypes::df;
18
19pub(crate) type DfStatement = (String, Option<(sqlparser::ast::Statement, LogicalPlan)>);
20
21/// Hook for processing cursor-related statements (DECLARE/FETCH/CLOSE)
22#[derive(Debug)]
23pub struct CursorStatementHook;
24
25#[async_trait]
26impl QueryHook for CursorStatementHook {
27    async fn handle_simple_query(
28        &self,
29        statement: &sqlparser::ast::Statement,
30        session_context: &SessionContext,
31        client: &mut dyn HookClient,
32    ) -> Option<PgWireResult<Response>> {
33        let store = client.portal_store();
34
35        match statement {
36            sqlparser::ast::Statement::Declare { stmts } => {
37                Some(handle_declare(store, stmts, session_context).await)
38            }
39            sqlparser::ast::Statement::Fetch {
40                name, direction, ..
41            } => Some(handle_fetch(store, name, direction).await),
42            sqlparser::ast::Statement::Close { cursor } => Some(handle_close(store, cursor)),
43            _ => None,
44        }
45    }
46
47    async fn handle_extended_parse_query(
48        &self,
49        statement: &sqlparser::ast::Statement,
50        _session_context: &SessionContext,
51        _client: &(dyn ClientInfo + Send + Sync),
52    ) -> Option<PgWireResult<LogicalPlan>> {
53        match statement {
54            sqlparser::ast::Statement::Declare { .. }
55            | sqlparser::ast::Statement::Fetch { .. }
56            | sqlparser::ast::Statement::Close { .. } => Some(Ok(LogicalPlan::EmptyRelation(
57                datafusion::logical_expr::EmptyRelation {
58                    produce_one_row: false,
59                    schema: Arc::new(datafusion::common::DFSchema::empty()),
60                },
61            ))),
62            _ => None,
63        }
64    }
65
66    async fn handle_extended_query(
67        &self,
68        statement: &sqlparser::ast::Statement,
69        _logical_plan: &LogicalPlan,
70        _params: &ParamValues,
71        session_context: &SessionContext,
72        client: &mut dyn HookClient,
73    ) -> Option<PgWireResult<Response>> {
74        let store = client.portal_store();
75
76        match statement {
77            sqlparser::ast::Statement::Declare { stmts } => {
78                Some(handle_declare(store, stmts, session_context).await)
79            }
80            sqlparser::ast::Statement::Fetch {
81                name, direction, ..
82            } => Some(handle_fetch(store, name, direction).await),
83            sqlparser::ast::Statement::Close { cursor } => Some(handle_close(store, cursor)),
84            _ => None,
85        }
86    }
87}
88
89async fn handle_declare(
90    store: &MemPortalStore<DfStatement>,
91    stmts: &[datafusion::sql::sqlparser::ast::Declare],
92    session_context: &SessionContext,
93) -> PgWireResult<Response> {
94    for declare in stmts {
95        if declare.declare_type != Some(DeclareType::Cursor) {
96            return Err(PgWireError::UserError(Box::new(
97                pgwire::error::ErrorInfo::new(
98                    "ERROR".to_string(),
99                    "42601".to_string(),
100                    format!("unsupported DECLARE type: {:?}", declare.declare_type),
101                ),
102            )));
103        }
104
105        let cursor_name = match declare.names.first() {
106            Some(name) => name.value.clone(),
107            None => {
108                return Err(PgWireError::UserError(Box::new(
109                    pgwire::error::ErrorInfo::new(
110                        "ERROR".to_string(),
111                        "42601".to_string(),
112                        "cursor name is required".to_string(),
113                    ),
114                )));
115            }
116        };
117
118        let for_query = match &declare.for_query {
119            Some(q) => q.to_string(),
120            None => {
121                return Err(PgWireError::UserError(Box::new(
122                    pgwire::error::ErrorInfo::new(
123                        "ERROR".to_string(),
124                        "42601".to_string(),
125                        "DECLARE CURSOR requires a FOR query".to_string(),
126                    ),
127                )));
128            }
129        };
130
131        let df = session_context
132            .sql(&for_query)
133            .await
134            .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
135
136        let query_response = df::encode_dataframe(df, &Format::UnifiedText, None).await?;
137
138        let stored_stmt = Arc::new(StoredStatement::new(
139            cursor_name.clone(),
140            (for_query, None),
141            vec![],
142        ));
143
144        let portal = Portal::new_cursor(cursor_name.clone(), stored_stmt);
145
146        portal.start(query_response).await;
147
148        store.put_portal(Arc::new(portal));
149    }
150
151    Ok(Response::Execution(Tag::new("DECLARE CURSOR")))
152}
153
154async fn handle_fetch(
155    store: &MemPortalStore<DfStatement>,
156    name: &datafusion::sql::sqlparser::ast::Ident,
157    direction: &FetchDirection,
158) -> PgWireResult<Response> {
159    let cursor_name = &name.value;
160
161    let max_rows = match direction {
162        FetchDirection::Next | FetchDirection::Forward { limit: None } => Some(1),
163        FetchDirection::Forward { limit: Some(v) } | FetchDirection::Count { limit: v } => {
164            parse_value_as_usize(v)
165        }
166        FetchDirection::ForwardAll | FetchDirection::All => None,
167        FetchDirection::Prior | FetchDirection::Backward { .. } | FetchDirection::BackwardAll => {
168            return Err(PgWireError::UserError(Box::new(
169                pgwire::error::ErrorInfo::new(
170                    "ERROR".to_string(),
171                    "42000".to_string(),
172                    "cursor can only scan forward".to_string(),
173                ),
174            )));
175        }
176        FetchDirection::First
177        | FetchDirection::Last
178        | FetchDirection::Absolute { .. }
179        | FetchDirection::Relative { .. } => {
180            return Err(PgWireError::UserError(Box::new(
181                pgwire::error::ErrorInfo::new(
182                    "ERROR".to_string(),
183                    "42000".to_string(),
184                    "cursor can only scan forward".to_string(),
185                ),
186            )));
187        }
188    };
189
190    let portal = store.get_portal(cursor_name).ok_or_else(|| {
191        PgWireError::UserError(Box::new(pgwire::error::ErrorInfo::new(
192            "ERROR".to_string(),
193            "34000".to_string(),
194            format!("cursor \"{cursor_name}\" does not exist"),
195        )))
196    })?;
197
198    let fetch_result = portal.fetch(max_rows.unwrap_or(0)).await?;
199
200    Ok(Response::Query(fetch_result.response))
201}
202
203fn handle_close(
204    store: &MemPortalStore<DfStatement>,
205    cursor: &CloseCursor,
206) -> PgWireResult<Response> {
207    match cursor {
208        CloseCursor::All => {
209            store.clear_portals();
210        }
211        CloseCursor::Specific { name } => {
212            store.rm_portal(&name.value);
213        }
214    }
215    Ok(Response::Execution(Tag::new("CLOSE CURSOR")))
216}
217
218fn parse_value_as_usize(value: &datafusion::sql::sqlparser::ast::Value) -> Option<usize> {
219    match value {
220        datafusion::sql::sqlparser::ast::Value::Number(s, _) => s.parse().ok(),
221        _ => None,
222    }
223}