1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
#![feature(plugin_registrar, rustc_private, libc)]

extern crate libc;
extern crate rustc;
extern crate rustc_plugin;
extern crate syntax;

use rustc_plugin::Registry;
use std::ffi::{CStr, CString};
use std::mem;
use std::str;
use syntax::ast::{ExprKind, LitKind, Expr, Ident};
use syntax::codemap::Span;
use syntax::ext::base::{ExtCtxt, MacResult, MacEager, DummyResult};
use syntax::ext::build::AstBuilder;
use syntax::fold::Folder;
use syntax::symbol::InternedString;
use syntax::symbol::Symbol;
use syntax::parse::token::{Comma, Eof};
use syntax::parse;
use syntax::parse::parser::Parser;
use syntax::ptr::P;
use syntax::tokenstream::TokenTree;

/// Raw bindings to the external C query parser used by `sql!` / `execute!`.
///
/// NOTE(review): no `#[link]` attribute is visible in this file, so the C
/// library providing these symbols is presumably linked by the build setup —
/// confirm there.
mod ffi {
    use libc::{c_char, c_int};

    /// Out-parameter written by `parse_query`.
    ///
    /// `#[repr(C)]` so the field order and layout match the C definition; do not
    /// reorder or retype these fields without changing the C side as well.
    #[repr(C)]
    pub struct ParseResult {
        /// Non-zero when the query parsed successfully.
        pub success: c_int,
        /// On failure, a NUL-terminated error message that the Rust side reads
        /// through `CStr` (ownership/lifetime is decided by the C library —
        /// TODO confirm; the Rust side never frees it).
        pub error_message: *const c_char,
        /// On failure, the error position that is reported to the user.
        pub index: c_int,
        /// On success, the number of query parameters; a negative value is
        /// treated by the Rust side as "could not be determined".
        pub num_params: c_int,
    }

    extern {
        /// Parser initialisation hook; invoked once from the plugin registrar.
        pub fn init_parser();
        /// Parses the NUL-terminated `query` and writes the outcome into `*result`.
        pub fn parse_query(query: *const c_char, result: *mut ParseResult);
    }
}

/// Information returned for a query the C parser accepted.
struct ParseInfo {
    /// Parameter count reported by the C parser, or `None` when the parser
    /// reported a negative (i.e. undetermined) count.
    num_params: Option<usize>,
}

/// Details of a query the C parser rejected.
struct ParseError {
    /// Position of the syntax error as reported by the C parser.
    index: usize,
    /// Human-readable description of the error, taken from the C parser.
    message: String,
}

/// Compiler-plugin entry point.
///
/// Registers the `sql!` and `execute!` macro expanders, then calls
/// `ffi::init_parser` so the C parser used by both expanders (via `parse`)
/// has been set up before either macro is expanded.
#[plugin_registrar]
#[doc(hidden)]
pub fn registrar(reg: &mut Registry) {
    reg.register_macro("sql", expand_sql);
    reg.register_macro("execute", expand_execute);
    // SAFETY(review): assumes `init_parser` has no preconditions other than
    // being called before `parse_query` — TODO confirm against the C library.
    unsafe { ffi::init_parser() }
}

fn expand_sql(cx: &mut ExtCtxt, sp: Span, tts: &[TokenTree])
              -> Box<MacResult+'static> {
    let mut parser = parse::new_parser_from_tts(cx.parse_sess(), tts.to_vec());

    let query_expr = cx.expander().fold_expr(parser.parse_expr().unwrap());
    let query = match parse_str_lit(cx, &*query_expr) {
        Some(query) => query,
        None => return DummyResult::expr(sp)
    };

    match parse(&query) {
        Ok(_) => {}
        Err(err) => parse_error(cx, query_expr.span, err),
    }

    MacEager::expr(query_expr)
}

/// Expands `execute!(conn, "query", args...)` into
/// `conn.execute("query", &[args...])`.
///
/// The query must be a string literal after macro expansion. It is checked by
/// the C parser at compile time, and the number of arguments is compared with
/// the parameter count the parser reports. If the count is unknown, only a
/// warning is emitted. Malformed macro input is reported as a compile error
/// rather than panicking inside the compiler.
fn expand_execute(cx: &mut ExtCtxt, sp: Span, tts: &[TokenTree])
                  -> Box<MacResult+'static> {
    let mut parser = parse::new_parser_from_tts(cx.parse_sess(), tts.to_vec());

    // Emit parse diagnostics instead of `unwrap`ping, which would be an ICE.
    let conn = match parser.parse_expr() {
        Ok(expr) => expr,
        Err(mut err) => {
            err.emit();
            return DummyResult::expr(sp);
        }
    };

    if !parser.eat(&Comma) {
        cx.span_err(parser.span, "expected `,`");
        return DummyResult::expr(sp);
    }

    let query_expr = match parser.parse_expr() {
        Ok(expr) => cx.expander().fold_expr(expr),
        Err(mut err) => {
            err.emit();
            return DummyResult::expr(sp);
        }
    };
    let query = match parse_str_lit(cx, &*query_expr) {
        Some(query) => query,
        None => return DummyResult::expr(sp),
    };

    // Either the input ends right after the query, or a comma introduces the
    // (possibly empty) argument list.
    if parser.token != Eof && !parser.eat(&Comma) {
        cx.span_err(parser.span, "expected `,`");
        return DummyResult::expr(sp);
    }

    let args = match parse_args(cx, &mut parser) {
        Some(args) => args,
        None => return DummyResult::expr(sp),
    };

    match parse(&query) {
        Ok(ParseInfo { num_params: None }) => {
            cx.span_warn(sp, "unable to verify the number of query parameters");
        }
        Ok(ParseInfo { num_params: Some(num_params) }) if num_params != args.len() => {
            cx.span_err(sp, &format!("Expected {} query parameters but got {}",
                                     num_params, args.len()));
        }
        Ok(_) => {}
        Err(err) => parse_error(cx, query_expr.span, err),
    }

    // Build `conn.execute(query, &[args...])`.
    let ident = Ident::with_empty_ctxt(Symbol::intern("execute"));
    let args = cx.expr_vec(sp, args);
    let args = cx.expr_addr_of(sp, args);
    MacEager::expr(cx.expr_method_call(sp, conn, ident, vec![query_expr, args]))
}

fn parse_error(cx: &mut ExtCtxt, sp: Span, err: ParseError) {
    cx.span_err(sp, &format!("Invalid syntax at position {}: {}", err.index, err.message));
}

/// Extracts the contents of a string-literal expression.
///
/// Returns `None` after emitting "expected string literal" at the expression's
/// span when `e` is not a string literal (including other kinds of literal).
fn parse_str_lit(cx: &mut ExtCtxt, e: &Expr) -> Option<InternedString> {
    if let ExprKind::Lit(ref literal) = e.node {
        if let LitKind::Str(ref contents, _) = literal.node {
            return Some(contents.as_str());
        }
    }

    cx.span_err(e.span, "expected string literal");
    None
}

/// Parses a comma-separated list of expressions up to the end of the macro
/// input, allowing a trailing comma.
///
/// Returns `None` after emitting a diagnostic if an argument fails to parse or
/// two arguments are not separated by a comma.
fn parse_args(cx: &mut ExtCtxt, parser: &mut Parser) -> Option<Vec<P<Expr>>> {
    let mut args = Vec::new();

    while parser.token != Eof {
        // Emit the parser's diagnostic rather than `unwrap`ping into an ICE.
        match parser.parse_expr() {
            Ok(arg) => args.push(arg),
            Err(mut err) => {
                err.emit();
                return None;
            }
        }

        if !parser.eat(&Comma) && parser.token != Eof {
            cx.span_err(parser.span, "expected `,`");
            return None;
        }
    }

    Some(args)
}

/// Runs `query` through the C parser.
///
/// On success, returns the parameter count. A negative count from C is mapped
/// to `None`, meaning "unknown".
///
/// # Errors
///
/// Returns a `ParseError` when the C parser rejects the query, or when the
/// query contains an interior NUL byte and so cannot be passed to C at all.
fn parse(query: &str) -> Result<ParseInfo, ParseError> {
    // The C side takes a NUL-terminated string, so an interior NUL cannot be
    // represented. Report it like a syntax error instead of panicking.
    let c_query = match CString::new(query.as_bytes()) {
        Ok(c_query) => c_query,
        Err(err) => {
            return Err(ParseError {
                message: "query contains an interior NUL byte".to_string(),
                index: err.nul_position(),
            });
        }
    };

    // Fully initialise the out-parameter. Reading fields of a
    // `mem::uninitialized()` value is undefined behaviour whenever the C side
    // leaves any of them unset.
    let mut result = ffi::ParseResult {
        success: 0,
        error_message: ::std::ptr::null(),
        index: 0,
        num_params: 0,
    };
    // SAFETY: `c_query` is a valid NUL-terminated string that outlives the
    // call, and `result` is a valid, writable `ParseResult`.
    unsafe { ffi::parse_query(c_query.as_ptr(), &mut result) };

    if result.success != 0 {
        let num_params = if result.num_params < 0 {
            None
        } else {
            Some(result.num_params as usize)
        };
        Ok(ParseInfo {
            num_params: num_params,
        })
    } else {
        // `CStr::from_ptr(null)` is UB, so guard against a missing message.
        // Decode lossily: a non-UTF-8 message should not crash the compiler.
        let message = if result.error_message.is_null() {
            "unknown parse error".to_string()
        } else {
            // SAFETY: non-null `error_message` is assumed to point to a
            // NUL-terminated string kept alive by the C parser at least until
            // the next call — TODO confirm against the C library.
            let bytes = unsafe { CStr::from_ptr(result.error_message) }.to_bytes();
            String::from_utf8_lossy(bytes).into_owned()
        };
        // A negative position would wrap to a huge `usize` with `as`; clamp it.
        let index = if result.index < 0 { 0 } else { result.index as usize };
        Err(ParseError {
            message: message,
            index: index,
        })
    }
}