pilota_thrift_parser/parser/
thrift.rs

1use std::path::PathBuf;
2
3use ariadne::{Color, Label, Report, ReportKind, Source};
4use chumsky::prelude::*;
5use faststr::FastStr;
6
7use super::super::{descriptor::File, parser::*};
8use crate::{
9    Constant, CppInclude, Enum, Exception, Include, Item, Namespace, Service, Struct, Typedef,
10    Union,
11};
12
13impl Item {
14    pub fn parse<'a>() -> impl Parser<'a, &'a str, Item, extra::Err<Rich<'a, char>>> {
15        choice((
16            Include::get_parser().map(Item::Include),
17            CppInclude::parse().map(Item::CppInclude),
18            Namespace::get_parser().map(Item::Namespace),
19            Typedef::get_parser().map(Item::Typedef),
20            Constant::get_parser().map(Item::Constant),
21            Enum::get_parser().map(Item::Enum),
22            Struct::get_parser().map(Item::Struct),
23            Union::parse().map(Item::Union),
24            Exception::parse().map(Item::Exception),
25            Service::get_parser().map(Item::Service),
26        ))
27    }
28}
29
30pub struct FileSource<'a> {
31    path: Option<PathBuf>,
32    content: &'a str,
33}
34
35impl<'a> FileSource<'a> {
36    pub fn new(inline: &'a str) -> Self {
37        Self {
38            path: None,
39            content: inline,
40        }
41    }
42
43    pub fn new_with_path(path: PathBuf, content: &'a str) -> Result<Self, error::Error> {
44        if !path.exists() {
45            return Err(error::Error::FileNotFound(path));
46        }
47
48        Ok(Self {
49            path: Some(path),
50            content,
51        })
52    }
53}
54
55pub struct FileParser<'a> {
56    pub source: FileSource<'a>,
57}
58
59impl<'a> FileParser<'a> {
60    pub fn new(source: FileSource<'a>) -> Self {
61        Self { source }
62    }
63
64    pub fn parse(&self) -> Result<File, error::Error> {
65        let (ast, errs) = File::get_parser()
66            .parse(self.source.content)
67            .into_output_errors();
68
69        let path_str = match &self.source.path {
70            Some(path) => &path.display().to_string(),
71            None => "inline",
72        };
73
74        if !errs.is_empty() {
75            let mut report_strings = Vec::with_capacity(errs.len() + 1);
76
77            let title = if errs.len() == 1 {
78                format!("Failed to parse thrift file: {}", path_str)
79            } else {
80                format!(
81                    "Failed to parse thrift file: {} ({} errors found)",
82                    path_str,
83                    errs.len()
84                )
85            };
86            report_strings.push(title);
87            report_strings.push(String::new());
88
89            for (i, e) in errs.iter().enumerate() {
90                if errs.len() > 1 {
91                    let error_header = format!("Error {}:", i + 1);
92                    report_strings.push(error_header.clone());
93                }
94
95                let mut buffer = Vec::new();
96                Report::build(ReportKind::Error, (path_str, e.span().into_range()))
97                    .with_config(ariadne::Config::new().with_index_type(ariadne::IndexType::Byte))
98                    .with_message(e.to_string())
99                    .with_label(
100                        Label::new((path_str, e.span().into_range()))
101                            .with_message(e.reason().to_string())
102                            .with_color(Color::Red),
103                    )
104                    .finish()
105                    .write((path_str, Source::from(self.source.content)), &mut buffer)
106                    .unwrap();
107                report_strings.push(String::from_utf8_lossy(&buffer).to_string());
108
109                if i < errs.len() - 1 {
110                    report_strings.push(String::new());
111                }
112            }
113
114            let report = report_strings.join("\n").into();
115            let summary = create_error_summary(&errs, path_str, self.source.content).into();
116            let custom_error = CustomSyntaxError { report };
117
118            return Err(error::Error::Syntax {
119                summary,
120                source: anyhow::anyhow!(custom_error),
121            });
122        }
123
124        Ok(ast.unwrap())
125    }
126}
127
128fn create_error_summary(errs: &[chumsky::error::Rich<char>], path_str: &str, text: &str) -> String {
129    if errs.is_empty() {
130        return String::new();
131    }
132
133    let mut summary = format!("Failed to parse thrift file: {}", path_str);
134
135    if errs.len() == 1 {
136        let err = &errs[0];
137        // 计算行号和列号
138        let (line, col) = calculate_line_col(err.span().start, text);
139        summary.push_str(&format!(" at line {}:{} - {}", line, col, err.reason()));
140    } else {
141        summary.push_str(&format!(" ({} errors found):", errs.len()));
142        for (i, err) in errs.iter().enumerate() {
143            let (line, col) = calculate_line_col(err.span().start, text);
144            summary.push_str(&format!(
145                "\n  {}. Line {}:{} - {}",
146                i + 1,
147                line,
148                col,
149                err.reason()
150            ));
151        }
152    }
153
154    summary
155}
156
/// Converts byte offset `pos` in `text` into a 1-based `(line, column)` pair.
///
/// Every character starting before `pos` is scanned. A `'\n'` starts a new line and resets
/// the column; any other character advances the column by one, so columns count characters,
/// not bytes. Offsets at or past the end of `text` resolve to the position just after the
/// last character.
fn calculate_line_col(pos: usize, text: &str) -> (usize, usize) {
    text.char_indices()
        .take_while(|&(offset, _)| offset < pos)
        .fold((1, 1), |(line, col), (_, ch)| match ch {
            '\n' => (line + 1, 1),
            _ => (line, col + 1),
        })
}
175
176#[derive(Debug)]
177pub struct CustomSyntaxError {
178    pub report: FastStr,
179}
180
181impl std::fmt::Display for CustomSyntaxError {
182    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
183        write!(f, "{}", self.report)
184    }
185}
186
187impl std::error::Error for CustomSyntaxError {}
188
189impl File {
190    pub(crate) fn get_parser<'a>() -> impl Parser<'a, &'a str, File, extra::Err<Rich<'a, char>>> {
191        Item::parse()
192            .repeated()
193            .collect()
194            .then(Components::comment().repeated().collect::<Vec<_>>())
195            .then_ignore(Components::blank().or_not())
196            .then_ignore(end())
197            .map(|(items, c): (Vec<Item>, Vec<FastStr>)| {
198                let mut comments = String::default();
199                for item in &items {
200                    match item {
201                        Item::Include(i) => {
202                            comments.push('\n');
203                            comments.push_str(&i.leading_comments);
204                            comments.push('\n');
205                            comments.push_str(&i.trailing_comments);
206                        }
207                        Item::Namespace(n) => {
208                            comments.push('\n');
209                            comments.push_str(&n.leading_comments);
210                            comments.push('\n');
211                            comments.push_str(&n.trailing_comments);
212                        }
213                        _ => {}
214                    }
215                }
216                for comment in c {
217                    comments.push('\n');
218                    comments.push_str(&comment);
219                    comments.push('\n');
220                }
221
222                let mut file = File {
223                    items,
224                    comments: comments.into(),
225                    ..Default::default()
226                };
227
228                let mut namespaces = file.items.iter().filter_map(|i| match i {
229                    Item::Namespace(ns) => Some(ns),
230                    _ => None,
231                });
232
233                file.package = namespaces
234                    .clone()
235                    .find_map(|n| {
236                        if n.scope.0 == "rs" {
237                            Some(n.name.clone())
238                        } else {
239                            None
240                        }
241                    })
242                    .or_else(|| {
243                        namespaces.clone().find_map(|n| {
244                            if n.scope.0 == "go" {
245                                Some(n.name.clone())
246                            } else {
247                                None
248                            }
249                        })
250                    })
251                    .or_else(|| {
252                        namespaces.find_map(|n| {
253                            if n.scope.0 == "*" {
254                                Some(n.name.clone())
255                            } else {
256                                None
257                            }
258                        })
259                    });
260
261                file
262            })
263    }
264}
265
#[cfg(test)]
mod tests {
    use ariadne::{Color, Label, Report, ReportKind, Source};

    use super::*;

    /// Pretty-prints every parse error in `errs` with ariadne, labelled against `body`.
    ///
    /// Panics (failing the calling test) if ariadne cannot print a report.
    fn print_errors(body: &str, errs: Vec<Rich<'_, char>>) {
        for e in errs {
            Report::build(ReportKind::Error, ("test.thrift", e.span().into_range()))
                .with_config(ariadne::Config::new().with_index_type(ariadne::IndexType::Byte))
                .with_message(e.to_string())
                .with_label(
                    Label::new(("test.thrift", e.span().into_range()))
                        .with_message(e.reason().to_string())
                        .with_color(Color::Red),
                )
                .finish()
                .print(("test.thrift", Source::from(body)))
                .unwrap();
        }
    }

    #[test]
    fn test_thrift() {
        let body = r#"
        namespace go http

        include "base.thrift"

        enum Sex {
            UNKNOWN = 0,
            MALE = 1,
            FEMALE = 2,
        }
        
        struct ReqItem{
            1: optional i64 id(api.js_conv = '', go.tag = 'json:"MyID" tagexpr:"$<0||$>=100"')
            2: optional string text='hello world'
            3: required string x
        }
        
        struct BizCommonParam {
            1: optional i64 api_version (api.query = 'api_version')
            2: optional i32 token(api.header = 'token')
        }
        
        struct BizRequest {
            1: optional i64 v_int64(api.query = 'v_int64', api.vd = "$>0&&$<200")
            2: optional string text(api.body = 'text')
            3: optional i32 token(api.header = 'token')
            4: optional map<i64, ReqItem> req_items_map (api.body='req_items_map')
            5: optional ReqItem some(api.body = 'some')
            6: optional list<string> req_items(api.query = 'req_items')
            7: optional i32 api_version(api.path = 'action')
            8: optional i64 uid(api.path = 'biz')
            9: optional list<i64> cids(api.query = 'cids')
            10: optional list<string> vids(api.query = 'vids')
            255: base.Base base
            256: optional BizCommonParam biz_common_param (agw.source='not_body_struct')
        }
        
        struct RspItem{
            1: optional i64 item_id
            2: optional string text
        }
        
        struct BizResponse {
            1: optional string T                             (api.header= 'T') 
            2: optional map<i64, RspItem> rsp_items           (api.body='rsp_items')
            3: optional i32 v_enum                       (api.none = '')
            4: optional list<RspItem> rsp_item_list            (api.body = 'rsp_item_list')
            5: optional i32 http_code                         (api.http_code = '') 
            6: optional list<i64> item_count (api.header = 'item_count')
        }
        
        exception Exception{
            1: i32 code (api.http_code = '') 
            2: string msg 
        }
        
        service BizService {
            BizResponse BizMethod1(1: BizRequest req)(api.get = '/life/client/:action/:biz', api.baseurl = 'ib.snssdk.com', api.param = 'true')
            BizResponse BizMethod2(1: BizRequest req)throws(1: Exception err)(api.post = '/life/client/:action/:biz', api.baseurl = 'ib.snssdk.com', api.param = 'true', api.serializer = 'form')
            BizResponse BizMethod3(1: BizRequest req)(api.post = '/life/client/:action/:biz/other', api.baseurl = 'ib.snssdk.com', api.param = 'true', api.serializer = 'json')
        }
        "#;
        let (file, errs) = File::get_parser().parse(body).into_output_errors();
        println!("{file:#?}");
        print_errors(body, errs);
    }

    #[test]
    fn test_separator() {
        let body = r#"typedef i32 MyInt32
typedef string MyString;

struct TypedefTestStruct {
  1: MyInt32 field_MyInt32;
  2: MyString field_MyString;
  3: i32 field_Int32;
  4: string field_String;
};

typedef TypedefTestStruct MyStruct,

const list<string> TEST_LIST = [
    "hello",
    "world",
];

service Service {
  MyStruct testEpisode(1:MyStruct arg)
},"#;
        let (file, errs) = File::get_parser().parse(body).into_output_errors();
        println!("{file:#?}");
        print_errors(body, errs);
    }

    #[test]
    fn test_only_comment() {
        let body = r#"
        /*** comment test ***/
        // comment 1

        # comment 2
        "#;
        let (file, errs) = File::get_parser().parse(body).into_output_errors();
        println!("{file:#?}");
        print_errors(body, errs);
    }

    #[test]
    fn test_enum() {
        let body = r#"
// Status enum represents the status of an operation
enum Status {
    // Success status
    SUCCESS = 0,
    // Error status
    ERROR = 1,
}"#;
        let (file, errs) = File::get_parser().parse(body).into_output_errors();
        println!("{file:#?}");
        print_errors(body, errs);
    }

    #[test]
    fn test_file_comments() {
        let body = r#"namespace rs volo.rpc.example

/*
 * Item struct represents an item with id, title, content, and extra metadata
 */

// This is a comment for the Item struct
struct Item {
    // id of the item
    1: required i64 id /* abc */ (go.tag="json:\"id\"")                   // id of the item

    /*
     * title of the item
     */
    2: required string title /*ddd*/ ,           // trailing comment test
    // content of the item
    3: required string content,             # trailing comment
    // extra metadata of the item
    10: optional map<string, string> extra, // trailing comment
}

// Status enum represents the status of an operation
enum Status {
    // Success status
    SUCCESS = 0,
    // Error status
    ERROR = 1,
}

// GetItemRequest struct represents the request for getting an item
struct GetItemRequest {
    1: required i64 id,
}

// GetItemResponse struct represents the response for getting an item
struct GetItemResponse {
    1: required Item item,
    2: required Status status,
}

// Test Service
// This is a comment for the TestService
service TestService {
    // method to get an item
    GetItemResponse getItem(1: GetItemRequest req),
}

// File comments test
// Another file comment line"#;
        let file = File::get_parser().parse(body).unwrap();
        println!("{:?}", file.comments);
    }
}